Refactor openssl_privatekey module, move add openssl_privatekey_pipe module (#119)

* Move disk-independent parts of openssl_privatekey to module_utils and doc_fragments.

* Improve documentation.

* Add openssl_privatekey_pipe module.

* Fallback in case no fingerprints are returned.

* Prevent no_log=True for content to stop module from working correctly.

* Forgot version_added.

* Update copyright. All the interesting code is no longer in this file anyway.

* Remove file arguments.

* Add framework for action modules.

* Convert openssl_privatekey_pipe to action plugin.

* Linting.

* Bump version.

* Add return_current_key option.

* Add no_log to examples.

* Remove preparation for potential later extensibility (easy to re-add when needed).

* Fix deprecation version in docs.

* Use new ArgumentSpec object for AnsibleActionModule as well.
This commit is contained in:
Felix Fontein
2020-10-28 21:52:54 +01:00
committed by GitHub
parent 9792188b0e
commit 3c21079afa
16 changed files with 1945 additions and 740 deletions

View File

@@ -0,0 +1,734 @@
# -*- coding: utf-8 -*-
#
# Copyright (c) 2012-2013 Michael DeHaan <michael.dehaan@gmail.com>
# Copyright (c) 2016 Toshio Kuratomi <tkuratomi@ansible.com>
# Copyright (c) 2019 Ansible Project
# Copyright (c) 2020 Felix Fontein <felix@fontein.de>
# GNU General Public License v3.0+ (see COPYING or https://www.gnu.org/licenses/gpl-3.0.txt)
# Parts taken from ansible.module_utils.basic and ansible.module_utils.common.warnings.
# NOTE: THIS MUST NOT BE USED BY A MODULE! THIS IS ONLY FOR ACTION PLUGINS!
from __future__ import absolute_import, division, print_function
__metaclass__ = type
import abc
import copy
import traceback
from ansible import constants as C
from ansible.errors import AnsibleError
from ansible.module_utils import six
from ansible.module_utils.basic import AnsibleFallbackNotFound, SEQUENCETYPE, remove_values
from ansible.module_utils.common._collections_compat import (
Mapping
)
from ansible.module_utils.common.parameters import (
handle_aliases,
list_deprecations,
list_no_log_values,
PASS_VARS,
PASS_BOOLS,
)
from ansible.module_utils.common.validation import (
check_mutually_exclusive,
check_required_arguments,
check_required_by,
check_required_if,
check_required_one_of,
check_required_together,
count_terms,
check_type_bool,
check_type_bits,
check_type_bytes,
check_type_float,
check_type_int,
check_type_jsonarg,
check_type_list,
check_type_dict,
check_type_path,
check_type_raw,
check_type_str,
safe_eval,
)
from ansible.module_utils.common.text.formatters import (
lenient_lowercase,
)
from ansible.module_utils.parsing.convert_bool import BOOLEANS_FALSE, BOOLEANS_TRUE
from ansible.module_utils.six import (
binary_type,
string_types,
text_type,
)
from ansible.module_utils._text import to_native, to_text
from ansible.plugins.action import ActionBase
class _ModuleExitException(Exception):
def __init__(self, result):
super(_ModuleExitException, self).__init__()
self.result = result
class AnsibleActionModule(object):
def __init__(self, action_plugin, argument_spec, bypass_checks=False,
mutually_exclusive=None, required_together=None,
required_one_of=None, supports_check_mode=False,
required_if=None, required_by=None):
# Internal data
self.__action_plugin = action_plugin
self.__warnings = []
self.__deprecations = []
# AnsibleModule data
self._name = self.__action_plugin._task.action
self.argument_spec = argument_spec
self.supports_check_mode = supports_check_mode
self.check_mode = self.__action_plugin._play_context.check_mode
self.bypass_checks = bypass_checks
self.no_log = self.__action_plugin._play_context.no_log
self.mutually_exclusive = mutually_exclusive
self.required_together = required_together
self.required_one_of = required_one_of
self.required_if = required_if
self.required_by = required_by
self._diff = self.__action_plugin._play_context.diff
self._verbosity = self.__action_plugin._display.verbosity
self._string_conversion_action = C.STRING_CONVERSION_ACTION
self.aliases = {}
self._legal_inputs = []
self.params = copy.deepcopy(action_plugin._task.args)
self._set_fallbacks()
# append to legal_inputs and then possibly check against them
try:
self.aliases = self._handle_aliases()
except (ValueError, TypeError) as e:
# Use exceptions here because it isn't safe to call fail_json until no_log is processed
raise _ModuleExitException(dict(failed=True, msg="Module alias error: %s" % to_native(e)))
# Save parameter values that should never be logged
self.no_log_values = set()
self._handle_no_log_values()
self._check_arguments()
# check exclusive early
if not bypass_checks:
self._check_mutually_exclusive(mutually_exclusive)
self._set_defaults(pre=True)
self._CHECK_ARGUMENT_TYPES_DISPATCHER = {
'str': self._check_type_str,
'list': self._check_type_list,
'dict': self._check_type_dict,
'bool': self._check_type_bool,
'int': self._check_type_int,
'float': self._check_type_float,
'path': self._check_type_path,
'raw': self._check_type_raw,
'jsonarg': self._check_type_jsonarg,
'json': self._check_type_jsonarg,
'bytes': self._check_type_bytes,
'bits': self._check_type_bits,
}
if not bypass_checks:
self._check_required_arguments()
self._check_argument_types()
self._check_argument_values()
self._check_required_together(required_together)
self._check_required_one_of(required_one_of)
self._check_required_if(required_if)
self._check_required_by(required_by)
self._set_defaults(pre=False)
# deal with options sub-spec
self._handle_options()
def _handle_aliases(self, spec=None, param=None, option_prefix=''):
if spec is None:
spec = self.argument_spec
if param is None:
param = self.params
# this uses exceptions as it happens before we can safely call fail_json
alias_warnings = []
alias_results, self._legal_inputs = handle_aliases(spec, param, alias_warnings=alias_warnings)
for option, alias in alias_warnings:
self.warn('Both option %s and its alias %s are set.' % (option_prefix + option, option_prefix + alias))
deprecated_aliases = []
for i in spec.keys():
if 'deprecated_aliases' in spec[i].keys():
for alias in spec[i]['deprecated_aliases']:
deprecated_aliases.append(alias)
for deprecation in deprecated_aliases:
if deprecation['name'] in param.keys():
self.deprecate("Alias '%s' is deprecated. See the module docs for more information" % deprecation['name'],
version=deprecation.get('version'), date=deprecation.get('date'),
collection_name=deprecation.get('collection_name'))
return alias_results
def _handle_no_log_values(self, spec=None, param=None):
if spec is None:
spec = self.argument_spec
if param is None:
param = self.params
try:
self.no_log_values.update(list_no_log_values(spec, param))
except TypeError as te:
self.fail_json(msg="Failure when processing no_log parameters. Module invocation will be hidden. "
"%s" % to_native(te), invocation={'module_args': 'HIDDEN DUE TO FAILURE'})
for message in list_deprecations(spec, param):
self.deprecate(message['msg'], version=message.get('version'), date=message.get('date'),
collection_name=message.get('collection_name'))
def _check_arguments(self, spec=None, param=None, legal_inputs=None):
self._syslog_facility = 'LOG_USER'
unsupported_parameters = set()
if spec is None:
spec = self.argument_spec
if param is None:
param = self.params
if legal_inputs is None:
legal_inputs = self._legal_inputs
for k in list(param.keys()):
if k not in legal_inputs:
unsupported_parameters.add(k)
for k in PASS_VARS:
# handle setting internal properties from internal ansible vars
param_key = '_ansible_%s' % k
if param_key in param:
if k in PASS_BOOLS:
setattr(self, PASS_VARS[k][0], self.boolean(param[param_key]))
else:
setattr(self, PASS_VARS[k][0], param[param_key])
# clean up internal top level params:
if param_key in self.params:
del self.params[param_key]
else:
# use defaults if not already set
if not hasattr(self, PASS_VARS[k][0]):
setattr(self, PASS_VARS[k][0], PASS_VARS[k][1])
if unsupported_parameters:
msg = "Unsupported parameters for (%s) module: %s" % (self._name, ', '.join(sorted(list(unsupported_parameters))))
if self._options_context:
msg += " found in %s." % " -> ".join(self._options_context)
supported_parameters = list()
for key in sorted(spec.keys()):
if 'aliases' in spec[key] and spec[key]['aliases']:
supported_parameters.append("%s (%s)" % (key, ', '.join(sorted(spec[key]['aliases']))))
else:
supported_parameters.append(key)
msg += " Supported parameters include: %s" % (', '.join(supported_parameters))
self.fail_json(msg=msg)
if self.check_mode and not self.supports_check_mode:
self.exit_json(skipped=True, msg="action module (%s) does not support check mode" % self._name)
def _count_terms(self, check, param=None):
if param is None:
param = self.params
return count_terms(check, param)
def _check_mutually_exclusive(self, spec, param=None):
if param is None:
param = self.params
try:
check_mutually_exclusive(spec, param)
except TypeError as e:
msg = to_native(e)
if self._options_context:
msg += " found in %s" % " -> ".join(self._options_context)
self.fail_json(msg=msg)
def _check_required_one_of(self, spec, param=None):
if spec is None:
return
if param is None:
param = self.params
try:
check_required_one_of(spec, param)
except TypeError as e:
msg = to_native(e)
if self._options_context:
msg += " found in %s" % " -> ".join(self._options_context)
self.fail_json(msg=msg)
def _check_required_together(self, spec, param=None):
if spec is None:
return
if param is None:
param = self.params
try:
check_required_together(spec, param)
except TypeError as e:
msg = to_native(e)
if self._options_context:
msg += " found in %s" % " -> ".join(self._options_context)
self.fail_json(msg=msg)
def _check_required_by(self, spec, param=None):
if spec is None:
return
if param is None:
param = self.params
try:
check_required_by(spec, param)
except TypeError as e:
self.fail_json(msg=to_native(e))
def _check_required_arguments(self, spec=None, param=None):
if spec is None:
spec = self.argument_spec
if param is None:
param = self.params
try:
check_required_arguments(spec, param)
except TypeError as e:
msg = to_native(e)
if self._options_context:
msg += " found in %s" % " -> ".join(self._options_context)
self.fail_json(msg=msg)
def _check_required_if(self, spec, param=None):
''' ensure that parameters which conditionally required are present '''
if spec is None:
return
if param is None:
param = self.params
try:
check_required_if(spec, param)
except TypeError as e:
msg = to_native(e)
if self._options_context:
msg += " found in %s" % " -> ".join(self._options_context)
self.fail_json(msg=msg)
def _check_argument_values(self, spec=None, param=None):
''' ensure all arguments have the requested values, and there are no stray arguments '''
if spec is None:
spec = self.argument_spec
if param is None:
param = self.params
for (k, v) in spec.items():
choices = v.get('choices', None)
if choices is None:
continue
if isinstance(choices, SEQUENCETYPE) and not isinstance(choices, (binary_type, text_type)):
if k in param:
# Allow one or more when type='list' param with choices
if isinstance(param[k], list):
diff_list = ", ".join([item for item in param[k] if item not in choices])
if diff_list:
choices_str = ", ".join([to_native(c) for c in choices])
msg = "value of %s must be one or more of: %s. Got no match for: %s" % (k, choices_str, diff_list)
if self._options_context:
msg += " found in %s" % " -> ".join(self._options_context)
self.fail_json(msg=msg)
elif param[k] not in choices:
# PyYaml converts certain strings to bools. If we can unambiguously convert back, do so before checking
# the value. If we can't figure this out, module author is responsible.
lowered_choices = None
if param[k] == 'False':
lowered_choices = lenient_lowercase(choices)
overlap = BOOLEANS_FALSE.intersection(choices)
if len(overlap) == 1:
# Extract from a set
(param[k],) = overlap
if param[k] == 'True':
if lowered_choices is None:
lowered_choices = lenient_lowercase(choices)
overlap = BOOLEANS_TRUE.intersection(choices)
if len(overlap) == 1:
(param[k],) = overlap
if param[k] not in choices:
choices_str = ", ".join([to_native(c) for c in choices])
msg = "value of %s must be one of: %s, got: %s" % (k, choices_str, param[k])
if self._options_context:
msg += " found in %s" % " -> ".join(self._options_context)
self.fail_json(msg=msg)
else:
msg = "internal error: choices for argument %s are not iterable: %s" % (k, choices)
if self._options_context:
msg += " found in %s" % " -> ".join(self._options_context)
self.fail_json(msg=msg)
def safe_eval(self, value, locals=None, include_exceptions=False):
return safe_eval(value, locals, include_exceptions)
def _check_type_str(self, value, param=None, prefix=''):
opts = {
'error': False,
'warn': False,
'ignore': True
}
# Ignore, warn, or error when converting to a string.
allow_conversion = opts.get(self._string_conversion_action, True)
try:
return check_type_str(value, allow_conversion)
except TypeError:
common_msg = 'quote the entire value to ensure it does not change.'
from_msg = '{0!r}'.format(value)
to_msg = '{0!r}'.format(to_text(value))
if param is not None:
if prefix:
param = '{0}{1}'.format(prefix, param)
from_msg = '{0}: {1!r}'.format(param, value)
to_msg = '{0}: {1!r}'.format(param, to_text(value))
if self._string_conversion_action == 'error':
msg = common_msg.capitalize()
raise TypeError(to_native(msg))
elif self._string_conversion_action == 'warn':
msg = ('The value "{0}" (type {1.__class__.__name__}) was converted to "{2}" (type string). '
'If this does not look like what you expect, {3}').format(from_msg, value, to_msg, common_msg)
self.warn(to_native(msg))
return to_native(value, errors='surrogate_or_strict')
def _check_type_list(self, value):
return check_type_list(value)
def _check_type_dict(self, value):
return check_type_dict(value)
def _check_type_bool(self, value):
return check_type_bool(value)
def _check_type_int(self, value):
return check_type_int(value)
def _check_type_float(self, value):
return check_type_float(value)
def _check_type_path(self, value):
return check_type_path(value)
def _check_type_jsonarg(self, value):
return check_type_jsonarg(value)
def _check_type_raw(self, value):
return check_type_raw(value)
def _check_type_bytes(self, value):
return check_type_bytes(value)
def _check_type_bits(self, value):
return check_type_bits(value)
def _handle_options(self, argument_spec=None, params=None, prefix=''):
''' deal with options to create sub spec '''
if argument_spec is None:
argument_spec = self.argument_spec
if params is None:
params = self.params
for (k, v) in argument_spec.items():
wanted = v.get('type', None)
if wanted == 'dict' or (wanted == 'list' and v.get('elements', '') == 'dict'):
spec = v.get('options', None)
if v.get('apply_defaults', False):
if spec is not None:
if params.get(k) is None:
params[k] = {}
else:
continue
elif spec is None or k not in params or params[k] is None:
continue
self._options_context.append(k)
if isinstance(params[k], dict):
elements = [params[k]]
else:
elements = params[k]
for idx, param in enumerate(elements):
if not isinstance(param, dict):
self.fail_json(msg="value of %s must be of type dict or list of dict" % k)
new_prefix = prefix + k
if wanted == 'list':
new_prefix += '[%d]' % idx
new_prefix += '.'
self._set_fallbacks(spec, param)
options_aliases = self._handle_aliases(spec, param, option_prefix=new_prefix)
options_legal_inputs = list(spec.keys()) + list(options_aliases.keys())
self._check_arguments(spec, param, options_legal_inputs)
# check exclusive early
if not self.bypass_checks:
self._check_mutually_exclusive(v.get('mutually_exclusive', None), param)
self._set_defaults(pre=True, spec=spec, param=param)
if not self.bypass_checks:
self._check_required_arguments(spec, param)
self._check_argument_types(spec, param, new_prefix)
self._check_argument_values(spec, param)
self._check_required_together(v.get('required_together', None), param)
self._check_required_one_of(v.get('required_one_of', None), param)
self._check_required_if(v.get('required_if', None), param)
self._check_required_by(v.get('required_by', None), param)
self._set_defaults(pre=False, spec=spec, param=param)
# handle multi level options (sub argspec)
self._handle_options(spec, param, new_prefix)
self._options_context.pop()
def _get_wanted_type(self, wanted, k):
if not callable(wanted):
if wanted is None:
# Mostly we want to default to str.
# For values set to None explicitly, return None instead as
# that allows a user to unset a parameter
wanted = 'str'
try:
type_checker = self._CHECK_ARGUMENT_TYPES_DISPATCHER[wanted]
except KeyError:
self.fail_json(msg="implementation error: unknown type %s requested for %s" % (wanted, k))
else:
# set the type_checker to the callable, and reset wanted to the callable's name (or type if it doesn't have one, ala MagicMock)
type_checker = wanted
wanted = getattr(wanted, '__name__', to_native(type(wanted)))
return type_checker, wanted
def _handle_elements(self, wanted, param, values):
type_checker, wanted_name = self._get_wanted_type(wanted, param)
validated_params = []
# Get param name for strings so we can later display this value in a useful error message if needed
# Only pass 'kwargs' to our checkers and ignore custom callable checkers
kwargs = {}
if wanted_name == 'str' and isinstance(wanted, string_types):
if isinstance(param, string_types):
kwargs['param'] = param
elif isinstance(param, dict):
kwargs['param'] = list(param.keys())[0]
for value in values:
try:
validated_params.append(type_checker(value, **kwargs))
except (TypeError, ValueError) as e:
msg = "Elements value for option %s" % param
if self._options_context:
msg += " found in '%s'" % " -> ".join(self._options_context)
msg += " is of type %s and we were unable to convert to %s: %s" % (type(value), wanted_name, to_native(e))
self.fail_json(msg=msg)
return validated_params
def _check_argument_types(self, spec=None, param=None, prefix=''):
''' ensure all arguments have the requested type '''
if spec is None:
spec = self.argument_spec
if param is None:
param = self.params
for (k, v) in spec.items():
wanted = v.get('type', None)
if k not in param:
continue
value = param[k]
if value is None:
continue
type_checker, wanted_name = self._get_wanted_type(wanted, k)
# Get param name for strings so we can later display this value in a useful error message if needed
# Only pass 'kwargs' to our checkers and ignore custom callable checkers
kwargs = {}
if wanted_name == 'str' and isinstance(type_checker, string_types):
kwargs['param'] = list(param.keys())[0]
# Get the name of the parent key if this is a nested option
if prefix:
kwargs['prefix'] = prefix
try:
param[k] = type_checker(value, **kwargs)
wanted_elements = v.get('elements', None)
if wanted_elements:
if wanted != 'list' or not isinstance(param[k], list):
msg = "Invalid type %s for option '%s'" % (wanted_name, param)
if self._options_context:
msg += " found in '%s'." % " -> ".join(self._options_context)
msg += ", elements value check is supported only with 'list' type"
self.fail_json(msg=msg)
param[k] = self._handle_elements(wanted_elements, k, param[k])
except (TypeError, ValueError) as e:
msg = "argument %s is of type %s" % (k, type(value))
if self._options_context:
msg += " found in '%s'." % " -> ".join(self._options_context)
msg += " and we were unable to convert to %s: %s" % (wanted_name, to_native(e))
self.fail_json(msg=msg)
def _set_defaults(self, pre=True, spec=None, param=None):
if spec is None:
spec = self.argument_spec
if param is None:
param = self.params
for (k, v) in spec.items():
default = v.get('default', None)
if pre is True:
# this prevents setting defaults on required items
if default is not None and k not in param:
param[k] = default
else:
# make sure things without a default still get set None
if k not in param:
param[k] = default
def _set_fallbacks(self, spec=None, param=None):
if spec is None:
spec = self.argument_spec
if param is None:
param = self.params
for (k, v) in spec.items():
fallback = v.get('fallback', (None,))
fallback_strategy = fallback[0]
fallback_args = []
fallback_kwargs = {}
if k not in param and fallback_strategy is not None:
for item in fallback[1:]:
if isinstance(item, dict):
fallback_kwargs = item
else:
fallback_args = item
try:
param[k] = fallback_strategy(*fallback_args, **fallback_kwargs)
except AnsibleFallbackNotFound:
continue
def warn(self, warning):
# Copied from ansible.module_utils.common.warnings:
if isinstance(warning, string_types):
self.__warnings.append(warning)
else:
raise TypeError("warn requires a string not a %s" % type(warning))
def deprecate(self, msg, version=None, date=None, collection_name=None):
if version is not None and date is not None:
raise AssertionError("implementation error -- version and date must not both be set")
# Copied from ansible.module_utils.common.warnings:
if isinstance(msg, string_types):
# For compatibility, we accept that neither version nor date is set,
# and treat that the same as if version would haven been set
if date is not None:
self.__deprecations.append({'msg': msg, 'date': date, 'collection_name': collection_name})
else:
self.__deprecations.append({'msg': msg, 'version': version, 'collection_name': collection_name})
else:
raise TypeError("deprecate requires a string not a %s" % type(msg))
def _return_formatted(self, kwargs):
if 'invocation' not in kwargs:
kwargs['invocation'] = {'module_args': self.params}
if 'warnings' in kwargs:
if isinstance(kwargs['warnings'], list):
for w in kwargs['warnings']:
self.warn(w)
else:
self.warn(kwargs['warnings'])
if self.__warnings:
kwargs['warnings'] = self.__warnings
if 'deprecations' in kwargs:
if isinstance(kwargs['deprecations'], list):
for d in kwargs['deprecations']:
if isinstance(d, SEQUENCETYPE) and len(d) == 2:
self.deprecate(d[0], version=d[1])
elif isinstance(d, Mapping):
self.deprecate(d['msg'], version=d.get('version'), date=d.get('date'),
collection_name=d.get('collection_name'))
else:
self.deprecate(d) # pylint: disable=ansible-deprecated-no-version
else:
self.deprecate(kwargs['deprecations']) # pylint: disable=ansible-deprecated-no-version
if self.__deprecations:
kwargs['deprecations'] = self.__deprecations
kwargs = remove_values(kwargs, self.no_log_values)
raise _ModuleExitException(kwargs)
def exit_json(self, **kwargs):
result = dict(kwargs)
if 'failed' not in result:
result['failed'] = False
self._return_formatted(result)
def fail_json(self, msg, **kwargs):
result = dict(kwargs)
result['failed'] = True
result['msg'] = msg
self._return_formatted(result)
@six.add_metaclass(abc.ABCMeta)
class ActionModuleBase(ActionBase):
@abc.abstractmethod
def setup_module(self):
"""Return pair (ArgumentSpec, kwargs)."""
pass
@abc.abstractmethod
def run_module(self, module):
"""Run module code"""
module.fail_json(msg='Not implemented.')
def run(self, tmp=None, task_vars=None):
if task_vars is None:
task_vars = dict()
result = super(ActionModuleBase, self).run(tmp, task_vars)
del tmp # tmp no longer has any effect
try:
argument_spec, kwargs = self.setup_module()
module = argument_spec.create_ansible_module_helper(AnsibleActionModule, (self, ), **kwargs)
self.run_module(module)
raise AnsibleError('Internal error: action module did not call module.exit_json()')
except _ModuleExitException as mee:
result.update(mee.result)
return result
except Exception as dummy:
result['failed'] = True
result['msg'] = 'MODULE FAILURE'
result['exception'] = traceback.format_exc()
return result

View File

@@ -19,8 +19,9 @@ class ArgumentSpec:
self.required_if = required_if or []
self.required_by = required_by or {}
def create_ansible_module(self, **kwargs):
return AnsibleModule(
def create_ansible_module_helper(self, clazz, args, **kwargs):
return clazz(
*args,
argument_spec=self.argument_spec,
mutually_exclusive=self.mutually_exclusive,
required_together=self.required_together,
@@ -28,3 +29,6 @@ class ArgumentSpec:
required_if=self.required_if,
required_by=self.required_by,
**kwargs)
def create_ansible_module(self, **kwargs):
return self.create_ansible_module_helper(AnsibleModule, (), **kwargs)

View File

@@ -0,0 +1,590 @@
# -*- coding: utf-8 -*-
#
# Copyright: (c) 2016, Yanis Guenane <yanis+ansible@guenane.org>
# Copyright: (c) 2020, Felix Fontein <felix@fontein.de>
# GNU General Public License v3.0+ (see COPYING or https://www.gnu.org/licenses/gpl-3.0.txt)
from __future__ import absolute_import, division, print_function
__metaclass__ = type
import abc
import base64
import traceback
from distutils.version import LooseVersion
from ansible.module_utils import six
from ansible.module_utils.basic import missing_required_lib
from ansible.module_utils._text import to_bytes
from ansible_collections.community.crypto.plugins.module_utils.crypto.basic import (
CRYPTOGRAPHY_HAS_X25519,
CRYPTOGRAPHY_HAS_X25519_FULL,
CRYPTOGRAPHY_HAS_X448,
CRYPTOGRAPHY_HAS_ED25519,
CRYPTOGRAPHY_HAS_ED448,
OpenSSLObjectError,
OpenSSLBadPassphraseError,
)
from ansible_collections.community.crypto.plugins.module_utils.crypto.support import (
load_privatekey,
get_fingerprint_of_privatekey,
)
from ansible_collections.community.crypto.plugins.module_utils.crypto.identify import (
identify_private_key_format,
)
from ansible_collections.community.crypto.plugins.module_utils.crypto.module_backends.common import ArgumentSpec
MINIMAL_PYOPENSSL_VERSION = '0.6'
MINIMAL_CRYPTOGRAPHY_VERSION = '1.2.3'
PYOPENSSL_IMP_ERR = None
try:
import OpenSSL
from OpenSSL import crypto
PYOPENSSL_VERSION = LooseVersion(OpenSSL.__version__)
except ImportError:
PYOPENSSL_IMP_ERR = traceback.format_exc()
PYOPENSSL_FOUND = False
else:
PYOPENSSL_FOUND = True
CRYPTOGRAPHY_IMP_ERR = None
try:
import cryptography
import cryptography.exceptions
import cryptography.hazmat.backends
import cryptography.hazmat.primitives.serialization
import cryptography.hazmat.primitives.asymmetric.rsa
import cryptography.hazmat.primitives.asymmetric.dsa
import cryptography.hazmat.primitives.asymmetric.ec
import cryptography.hazmat.primitives.asymmetric.utils
CRYPTOGRAPHY_VERSION = LooseVersion(cryptography.__version__)
except ImportError:
CRYPTOGRAPHY_IMP_ERR = traceback.format_exc()
CRYPTOGRAPHY_FOUND = False
else:
CRYPTOGRAPHY_FOUND = True
class PrivateKeyError(OpenSSLObjectError):
pass
# From the object called `module`, only the following properties are used:
#
# - module.params[]
# - module.warn(msg: str)
# - module.fail_json(msg: str, **kwargs)
@six.add_metaclass(abc.ABCMeta)
class PrivateKeyBackend:
def __init__(self, module, backend):
self.module = module
self.type = module.params['type']
self.size = module.params['size']
self.curve = module.params['curve']
self.passphrase = module.params['passphrase']
self.cipher = module.params['cipher']
self.format = module.params['format']
self.format_mismatch = module.params.get('format_mismatch', 'regenerate')
self.regenerate = module.params.get('regenerate', 'full_idempotence')
self.backend = backend
self.private_key = None
self.existing_private_key = None
self.existing_private_key_bytes = None
@abc.abstractmethod
def generate_private_key(self):
"""(Re-)Generate private key."""
pass
def convert_private_key(self):
"""Convert existing private key (self.existing_private_key) to new private key (self.private_key).
This is effectively a copy without active conversion. The conversion is done
during load and store; get_private_key_data() uses the destination format to
serialize the key.
"""
self._ensure_existing_private_key_loaded()
self.private_key = self.existing_private_key
@abc.abstractmethod
def get_private_key_data(self):
"""Return bytes for self.private_key."""
pass
def set_existing(self, privatekey_bytes):
"""Set existing private key bytes. None indicates that the key does not exist."""
self.existing_private_key_bytes = privatekey_bytes
def has_existing(self):
"""Query whether an existing private key is/has been there."""
return self.existing_private_key_bytes is not None
@abc.abstractmethod
def _check_passphrase(self):
"""Check whether provided passphrase matches, assuming self.existing_private_key_bytes has been populated."""
pass
@abc.abstractmethod
def _ensure_existing_private_key_loaded(self):
"""Make sure that self.existing_private_key is populated from self.existing_private_key_bytes."""
pass
@abc.abstractmethod
def _check_size_and_type(self):
"""Check whether provided size and type matches, assuming self.existing_private_key has been populated."""
pass
@abc.abstractmethod
def _check_format(self):
"""Check whether the key file format, assuming self.existing_private_key and self.existing_private_key_bytes has been populated."""
pass
def needs_regeneration(self):
"""Check whether a regeneration is necessary."""
if self.regenerate == 'always':
return True
if not self.has_existing():
# key does not exist
return True
if not self._check_passphrase():
if self.regenerate == 'full_idempotence':
return True
self.module.fail_json(msg='Unable to read the key. The key is protected with a another passphrase / no passphrase or broken.'
' Will not proceed. To force regeneration, call the module with `generate`'
' set to `full_idempotence` or `always`, or with `force=yes`.')
self._ensure_existing_private_key_loaded()
if self.regenerate != 'never':
if not self._check_size_and_type():
if self.regenerate in ('partial_idempotence', 'full_idempotence'):
return True
self.module.fail_json(msg='Key has wrong type and/or size.'
' Will not proceed. To force regeneration, call the module with `generate`'
' set to `partial_idempotence`, `full_idempotence` or `always`, or with `force=yes`.')
# During generation step, regenerate if format does not match and format_mismatch == 'regenerate'
if self.format_mismatch == 'regenerate' and self.regenerate != 'never':
if not self._check_format():
if self.regenerate in ('partial_idempotence', 'full_idempotence'):
return True
self.module.fail_json(msg='Key has wrong format.'
' Will not proceed. To force regeneration, call the module with `generate`'
' set to `partial_idempotence`, `full_idempotence` or `always`, or with `force=yes`.'
' To convert the key, set `format_mismatch` to `convert`.')
return False
def needs_conversion(self):
"""Check whether a conversion is necessary. Must only be called if needs_regeneration() returned False."""
# During conversion step, convert if format does not match and format_mismatch == 'convert'
self._ensure_existing_private_key_loaded()
return self.has_existing() and self.format_mismatch == 'convert' and not self._check_format()
def _get_fingerprint(self):
if self.private_key:
return get_fingerprint_of_privatekey(self.private_key, backend=self.backend)
try:
self._ensure_existing_private_key_loaded()
except Exception as dummy:
# Ignore errors
pass
if self.existing_private_key:
return get_fingerprint_of_privatekey(self.existing_private_key, backend=self.backend)
def dump(self, include_key):
"""Serialize the object into a dictionary."""
if not self.private_key:
try:
self._ensure_existing_private_key_loaded()
except Exception as dummy:
# Ignore errors
pass
result = {
'type': self.type,
'size': self.size,
'fingerprint': self._get_fingerprint(),
}
if self.type == 'ECC':
result['curve'] = self.curve
if include_key:
# Get hold of private key bytes
pk_bytes = self.existing_private_key_bytes
if self.private_key is not None:
pk_bytes = self.get_private_key_data()
# Store result
if pk_bytes:
if identify_private_key_format(pk_bytes) == 'raw':
result['privatekey'] = base64.b64encode(pk_bytes)
else:
result['privatekey'] = pk_bytes.decode('utf-8')
else:
result['privatekey'] = None
return result
# Implementation with using pyOpenSSL
class PrivateKeyPyOpenSSLBackend(PrivateKeyBackend):
def __init__(self, module):
super(PrivateKeyPyOpenSSLBackend, self).__init__(module=module, backend='pyopenssl')
if self.type == 'RSA':
self.openssl_type = crypto.TYPE_RSA
elif self.type == 'DSA':
self.openssl_type = crypto.TYPE_DSA
else:
self.module.fail_json(msg="PyOpenSSL backend only supports RSA and DSA keys.")
if self.format != 'auto_ignore':
self.module.fail_json(msg="PyOpenSSL backend only supports auto_ignore format.")
def generate_private_key(self):
"""(Re-)Generate private key."""
self.private_key = crypto.PKey()
try:
self.private_key.generate_key(self.openssl_type, self.size)
except (TypeError, ValueError) as exc:
raise PrivateKeyError(exc)
def _ensure_existing_private_key_loaded(self):
if self.existing_private_key is None and self.has_existing():
try:
self.existing_private_key = load_privatekey(
None, self.passphrase, content=self.existing_private_key_bytes, backend=self.backend)
except OpenSSLBadPassphraseError as exc:
raise PrivateKeyError(exc)
def get_private_key_data(self):
"""Return bytes for self.private_key"""
if self.cipher and self.passphrase:
return crypto.dump_privatekey(crypto.FILETYPE_PEM, self.private_key,
self.cipher, to_bytes(self.passphrase))
else:
return crypto.dump_privatekey(crypto.FILETYPE_PEM, self.private_key)
def _check_passphrase(self):
try:
load_privatekey(None, self.passphrase, content=self.existing_private_key_bytes, backend=self.backend)
return True
except Exception as dummy:
return False
def _check_size_and_type(self):
return self.size == self.existing_private_key.bits() and self.openssl_type == self.existing_private_key.type()
def _check_format(self):
# Not supported by this backend
return True
# Implementation with using cryptography
class PrivateKeyCryptographyBackend(PrivateKeyBackend):
def _get_ec_class(self, ectype):
ecclass = cryptography.hazmat.primitives.asymmetric.ec.__dict__.get(ectype)
if ecclass is None:
self.module.fail_json(msg='Your cryptography version does not support {0}'.format(ectype))
return ecclass
def _add_curve(self, name, ectype, deprecated=False):
def create(size):
ecclass = self._get_ec_class(ectype)
return ecclass()
def verify(privatekey):
ecclass = self._get_ec_class(ectype)
return isinstance(privatekey.private_numbers().public_numbers.curve, ecclass)
self.curves[name] = {
'create': create,
'verify': verify,
'deprecated': deprecated,
}
def __init__(self, module):
super(PrivateKeyCryptographyBackend, self).__init__(module=module, backend='cryptography')
self.curves = dict()
self._add_curve('secp384r1', 'SECP384R1')
self._add_curve('secp521r1', 'SECP521R1')
self._add_curve('secp224r1', 'SECP224R1')
self._add_curve('secp192r1', 'SECP192R1')
self._add_curve('secp256r1', 'SECP256R1')
self._add_curve('secp256k1', 'SECP256K1')
self._add_curve('brainpoolP256r1', 'BrainpoolP256R1', deprecated=True)
self._add_curve('brainpoolP384r1', 'BrainpoolP384R1', deprecated=True)
self._add_curve('brainpoolP512r1', 'BrainpoolP512R1', deprecated=True)
self._add_curve('sect571k1', 'SECT571K1', deprecated=True)
self._add_curve('sect409k1', 'SECT409K1', deprecated=True)
self._add_curve('sect283k1', 'SECT283K1', deprecated=True)
self._add_curve('sect233k1', 'SECT233K1', deprecated=True)
self._add_curve('sect163k1', 'SECT163K1', deprecated=True)
self._add_curve('sect571r1', 'SECT571R1', deprecated=True)
self._add_curve('sect409r1', 'SECT409R1', deprecated=True)
self._add_curve('sect283r1', 'SECT283R1', deprecated=True)
self._add_curve('sect233r1', 'SECT233R1', deprecated=True)
self._add_curve('sect163r2', 'SECT163R2', deprecated=True)
self.cryptography_backend = cryptography.hazmat.backends.default_backend()
if not CRYPTOGRAPHY_HAS_X25519 and self.type == 'X25519':
self.module.fail_json(msg='Your cryptography version does not support X25519')
if not CRYPTOGRAPHY_HAS_X25519_FULL and self.type == 'X25519':
self.module.fail_json(msg='Your cryptography version does not support X25519 serialization')
if not CRYPTOGRAPHY_HAS_X448 and self.type == 'X448':
self.module.fail_json(msg='Your cryptography version does not support X448')
if not CRYPTOGRAPHY_HAS_ED25519 and self.type == 'Ed25519':
self.module.fail_json(msg='Your cryptography version does not support Ed25519')
if not CRYPTOGRAPHY_HAS_ED448 and self.type == 'Ed448':
self.module.fail_json(msg='Your cryptography version does not support Ed448')
def _get_wanted_format(self):
if self.format not in ('auto', 'auto_ignore'):
return self.format
if self.type in ('X25519', 'X448', 'Ed25519', 'Ed448'):
return 'pkcs8'
else:
return 'pkcs1'
def generate_private_key(self):
"""(Re-)Generate private key."""
try:
if self.type == 'RSA':
self.private_key = cryptography.hazmat.primitives.asymmetric.rsa.generate_private_key(
public_exponent=65537, # OpenSSL always uses this
key_size=self.size,
backend=self.cryptography_backend
)
if self.type == 'DSA':
self.private_key = cryptography.hazmat.primitives.asymmetric.dsa.generate_private_key(
key_size=self.size,
backend=self.cryptography_backend
)
if CRYPTOGRAPHY_HAS_X25519_FULL and self.type == 'X25519':
self.private_key = cryptography.hazmat.primitives.asymmetric.x25519.X25519PrivateKey.generate()
if CRYPTOGRAPHY_HAS_X448 and self.type == 'X448':
self.private_key = cryptography.hazmat.primitives.asymmetric.x448.X448PrivateKey.generate()
if CRYPTOGRAPHY_HAS_ED25519 and self.type == 'Ed25519':
self.private_key = cryptography.hazmat.primitives.asymmetric.ed25519.Ed25519PrivateKey.generate()
if CRYPTOGRAPHY_HAS_ED448 and self.type == 'Ed448':
self.private_key = cryptography.hazmat.primitives.asymmetric.ed448.Ed448PrivateKey.generate()
if self.type == 'ECC' and self.curve in self.curves:
if self.curves[self.curve]['deprecated']:
self.module.warn('Elliptic curves of type {0} should not be used for new keys!'.format(self.curve))
self.private_key = cryptography.hazmat.primitives.asymmetric.ec.generate_private_key(
curve=self.curves[self.curve]['create'](self.size),
backend=self.cryptography_backend
)
except cryptography.exceptions.UnsupportedAlgorithm as dummy:
self.module.fail_json(msg='Cryptography backend does not support the algorithm required for {0}'.format(self.type))
def get_private_key_data(self):
"""Return bytes for self.private_key"""
# Select export format and encoding
try:
export_format = self._get_wanted_format()
export_encoding = cryptography.hazmat.primitives.serialization.Encoding.PEM
if export_format == 'pkcs1':
# "TraditionalOpenSSL" format is PKCS1
export_format = cryptography.hazmat.primitives.serialization.PrivateFormat.TraditionalOpenSSL
elif export_format == 'pkcs8':
export_format = cryptography.hazmat.primitives.serialization.PrivateFormat.PKCS8
elif export_format == 'raw':
export_format = cryptography.hazmat.primitives.serialization.PrivateFormat.Raw
export_encoding = cryptography.hazmat.primitives.serialization.Encoding.Raw
except AttributeError:
self.module.fail_json(msg='Cryptography backend does not support the selected output format "{0}"'.format(self.format))
# Select key encryption
encryption_algorithm = cryptography.hazmat.primitives.serialization.NoEncryption()
if self.cipher and self.passphrase:
if self.cipher == 'auto':
encryption_algorithm = cryptography.hazmat.primitives.serialization.BestAvailableEncryption(to_bytes(self.passphrase))
else:
self.module.fail_json(msg='Cryptography backend can only use "auto" for cipher option.')
# Serialize key
try:
return self.private_key.private_bytes(
encoding=export_encoding,
format=export_format,
encryption_algorithm=encryption_algorithm
)
except ValueError as dummy:
self.module.fail_json(
msg='Cryptography backend cannot serialize the private key in the required format "{0}"'.format(self.format)
)
except Exception as dummy:
self.module.fail_json(
msg='Error while serializing the private key in the required format "{0}"'.format(self.format),
exception=traceback.format_exc()
)
def _load_privatekey(self):
data = self.existing_private_key_bytes
try:
# Interpret bytes depending on format.
format = identify_private_key_format(data)
if format == 'raw':
if len(data) == 56 and CRYPTOGRAPHY_HAS_X448:
return cryptography.hazmat.primitives.asymmetric.x448.X448PrivateKey.from_private_bytes(data)
if len(data) == 57 and CRYPTOGRAPHY_HAS_ED448:
return cryptography.hazmat.primitives.asymmetric.ed448.Ed448PrivateKey.from_private_bytes(data)
if len(data) == 32:
if CRYPTOGRAPHY_HAS_X25519 and (self.type == 'X25519' or not CRYPTOGRAPHY_HAS_ED25519):
return cryptography.hazmat.primitives.asymmetric.x25519.X25519PrivateKey.from_private_bytes(data)
if CRYPTOGRAPHY_HAS_ED25519 and (self.type == 'Ed25519' or not CRYPTOGRAPHY_HAS_X25519):
return cryptography.hazmat.primitives.asymmetric.ed25519.Ed25519PrivateKey.from_private_bytes(data)
if CRYPTOGRAPHY_HAS_X25519 and CRYPTOGRAPHY_HAS_ED25519:
try:
return cryptography.hazmat.primitives.asymmetric.x25519.X25519PrivateKey.from_private_bytes(data)
except Exception:
return cryptography.hazmat.primitives.asymmetric.ed25519.Ed25519PrivateKey.from_private_bytes(data)
raise PrivateKeyError('Cannot load raw key')
else:
return cryptography.hazmat.primitives.serialization.load_pem_private_key(
data,
None if self.passphrase is None else to_bytes(self.passphrase),
backend=self.cryptography_backend
)
except Exception as e:
raise PrivateKeyError(e)
def _ensure_existing_private_key_loaded(self):
if self.existing_private_key is None and self.has_existing():
self.existing_private_key = self._load_privatekey()
def _check_passphrase(self):
try:
format = identify_private_key_format(self.existing_private_key_bytes)
if format == 'raw':
# Raw keys cannot be encrypted. To avoid incompatibilities, we try to
# actually load the key (and return False when this fails).
self._load_privatekey()
# Loading the key succeeded. Only return True when no passphrase was
# provided.
return self.passphrase is None
else:
return cryptography.hazmat.primitives.serialization.load_pem_private_key(
self.existing_private_key_bytes,
None if self.passphrase is None else to_bytes(self.passphrase),
backend=self.cryptography_backend
)
except Exception as dummy:
return False
def _check_size_and_type(self):
if isinstance(self.existing_private_key, cryptography.hazmat.primitives.asymmetric.rsa.RSAPrivateKey):
return self.type == 'RSA' and self.size == self.existing_private_key.key_size
if isinstance(self.existing_private_key, cryptography.hazmat.primitives.asymmetric.dsa.DSAPrivateKey):
return self.type == 'DSA' and self.size == self.existing_private_key.key_size
if CRYPTOGRAPHY_HAS_X25519 and isinstance(self.existing_private_key, cryptography.hazmat.primitives.asymmetric.x25519.X25519PrivateKey):
return self.type == 'X25519'
if CRYPTOGRAPHY_HAS_X448 and isinstance(self.existing_private_key, cryptography.hazmat.primitives.asymmetric.x448.X448PrivateKey):
return self.type == 'X448'
if CRYPTOGRAPHY_HAS_ED25519 and isinstance(self.existing_private_key, cryptography.hazmat.primitives.asymmetric.ed25519.Ed25519PrivateKey):
return self.type == 'Ed25519'
if CRYPTOGRAPHY_HAS_ED448 and isinstance(self.existing_private_key, cryptography.hazmat.primitives.asymmetric.ed448.Ed448PrivateKey):
return self.type == 'Ed448'
if isinstance(self.existing_private_key, cryptography.hazmat.primitives.asymmetric.ec.EllipticCurvePrivateKey):
if self.type != 'ECC':
return False
if self.curve not in self.curves:
return False
return self.curves[self.curve]['verify'](self.existing_private_key)
return False
def _check_format(self):
if self.format == 'auto_ignore':
return True
try:
format = identify_private_key_format(self.existing_private_key_bytes)
return format == self._get_wanted_format()
except Exception as dummy:
return False
def select_backend(module, backend):
if backend == 'auto':
# Detection what is possible
can_use_cryptography = CRYPTOGRAPHY_FOUND and CRYPTOGRAPHY_VERSION >= LooseVersion(MINIMAL_CRYPTOGRAPHY_VERSION)
can_use_pyopenssl = PYOPENSSL_FOUND and PYOPENSSL_VERSION >= LooseVersion(MINIMAL_PYOPENSSL_VERSION)
# Decision
if module.params['cipher'] and module.params['passphrase'] and module.params['cipher'] != 'auto':
# First try pyOpenSSL, then cryptography
if can_use_pyopenssl:
backend = 'pyopenssl'
elif can_use_cryptography:
backend = 'cryptography'
else:
# First try cryptography, then pyOpenSSL
if can_use_cryptography:
backend = 'cryptography'
elif can_use_pyopenssl:
backend = 'pyopenssl'
# Success?
if backend == 'auto':
module.fail_json(msg=("Can't detect any of the required Python libraries "
"cryptography (>= {0}) or PyOpenSSL (>= {1})").format(
MINIMAL_CRYPTOGRAPHY_VERSION,
MINIMAL_PYOPENSSL_VERSION))
if backend == 'pyopenssl':
if not PYOPENSSL_FOUND:
module.fail_json(msg=missing_required_lib('pyOpenSSL >= {0}'.format(MINIMAL_PYOPENSSL_VERSION)),
exception=PYOPENSSL_IMP_ERR)
module.deprecate('The module is using the PyOpenSSL backend. This backend has been deprecated',
version='2.0.0', collection_name='community.crypto')
return backend, PrivateKeyPyOpenSSLBackend(module)
elif backend == 'cryptography':
if not CRYPTOGRAPHY_FOUND:
module.fail_json(msg=missing_required_lib('cryptography >= {0}'.format(MINIMAL_CRYPTOGRAPHY_VERSION)),
exception=CRYPTOGRAPHY_IMP_ERR)
return backend, PrivateKeyCryptographyBackend(module)
else:
raise Exception('Unsupported value for backend: {0}'.format(backend))
def get_privatekey_argument_spec():
return ArgumentSpec(
argument_spec=dict(
size=dict(type='int', default=4096),
type=dict(type='str', default='RSA', choices=[
'DSA', 'ECC', 'Ed25519', 'Ed448', 'RSA', 'X25519', 'X448'
]),
curve=dict(type='str', choices=[
'secp384r1', 'secp521r1', 'secp224r1', 'secp192r1', 'secp256r1',
'secp256k1', 'brainpoolP256r1', 'brainpoolP384r1', 'brainpoolP512r1',
'sect571k1', 'sect409k1', 'sect283k1', 'sect233k1', 'sect163k1',
'sect571r1', 'sect409r1', 'sect283r1', 'sect233r1', 'sect163r2',
]),
passphrase=dict(type='str', no_log=True),
cipher=dict(type='str'),
format=dict(type='str', default='auto_ignore', choices=['pkcs1', 'pkcs8', 'raw', 'auto', 'auto_ignore']),
format_mismatch=dict(type='str', default='regenerate', choices=['regenerate', 'convert']),
select_crypto_backend=dict(type='str', choices=['auto', 'pyopenssl', 'cryptography'], default='auto'),
regenerate=dict(
type='str',
default='full_idempotence',
choices=['never', 'fail', 'partial_idempotence', 'full_idempotence', 'always']
),
),
required_together=[
['cipher', 'passphrase']
],
required_if=[
['type', 'ECC', ['curve']],
],
)

View File

@@ -83,11 +83,9 @@ def get_fingerprint_of_bytes(source):
return fingerprint
def get_fingerprint(path, passphrase=None, content=None, backend='pyopenssl'):
def get_fingerprint_of_privatekey(privatekey, backend='pyopenssl'):
"""Generate the fingerprint of the public key. """
privatekey = load_privatekey(path, passphrase=passphrase, content=content, check_passphrase=False, backend=backend)
if backend == 'pyopenssl':
try:
publickey = crypto.dump_publickey(crypto.FILETYPE_ASN1, privatekey)
@@ -112,6 +110,14 @@ def get_fingerprint(path, passphrase=None, content=None, backend='pyopenssl'):
return get_fingerprint_of_bytes(publickey)
def get_fingerprint(path, passphrase=None, content=None, backend='pyopenssl'):
"""Generate the fingerprint of the public key. """
privatekey = load_privatekey(path, passphrase=passphrase, content=content, check_passphrase=False, backend=backend)
return get_fingerprint_of_privatekey(privatekey, backend=backend)
def load_privatekey(path, passphrase=None, check_passphrase=True, content=None, backend='pyopenssl'):
"""Load the specified OpenSSL private key.
@@ -343,6 +349,10 @@ class OpenSSLObject(object):
def remove(self, module):
"""Remove the resource from the filesystem."""
if self.check_mode:
if os.path.exists(self.path):
self.changed = True
return
try:
os.remove(self.path)