Refactor common network shared and platform utils code into package (#33452)

* Refactor common network shared and platform specific code into package (part-1)

As per proposal #76 refactor common network shared and platform specific
code into sub-package.
https://github.com/ansible/proposals/issues/76

*  ansible.module_utils.network.common - command shared functions
*  ansible.module_utils.network.{{ platform }} - where platform is platform specific shared functions

*  Fix review comments

* Fix review comments
This commit is contained in:
Ganesh Nalawade
2017-12-03 21:42:30 +05:30
committed by GitHub
parent 18aca48075
commit 11c9ad23d5
483 changed files with 871 additions and 887 deletions

View File

@@ -0,0 +1,444 @@
# This code is part of Ansible, but is an independent component.
# This particular file snippet, and this file snippet only, is BSD licensed.
# Modules you write using this snippet, which is embedded dynamically by Ansible
# still belong to the author of the module, and may assign their own license
# to the complete work.
#
# (c) 2016 Red Hat Inc.
#
# Redistribution and use in source and binary forms, with or without modification,
# are permitted provided that the following conditions are met:
#
# * Redistributions of source code must retain the above copyright
# notice, this list of conditions and the following disclaimer.
# * Redistributions in binary form must reproduce the above copyright notice,
# this list of conditions and the following disclaimer in the documentation
# and/or other materials provided with the distribution.
#
# THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" AND
# ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE IMPLIED
# WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE DISCLAIMED.
# IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE FOR ANY DIRECT, INDIRECT,
# INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL DAMAGES (INCLUDING, BUT NOT LIMITED TO,
# PROCUREMENT OF SUBSTITUTE GOODS OR SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS
# INTERRUPTION) HOWEVER CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT
# LIABILITY, OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE
# USE OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.
#
import re
import hashlib
from ansible.module_utils.six.moves import zip
from ansible.module_utils._text import to_bytes, to_native
from ansible.module_utils.network.common.utils import to_list
DEFAULT_COMMENT_TOKENS = ['#', '!', '/*', '*/', 'echo']
DEFAULT_IGNORE_LINES_RE = set([
re.compile(r"Using \d+ out of \d+ bytes"),
re.compile(r"Building configuration"),
re.compile(r"Current configuration : \d+ bytes")
])
class ConfigLine(object):
def __init__(self, raw):
self.text = str(raw).strip()
self.raw = raw
self._children = list()
self._parents = list()
def __str__(self):
return self.raw
def __eq__(self, other):
return self.line == other.line
def __ne__(self, other):
return not self.__eq__(other)
def __getitem__(self, key):
for item in self._children:
if item.text == key:
return item
raise KeyError(key)
@property
def line(self):
line = self.parents
line.append(self.text)
return ' '.join(line)
@property
def children(self):
return _obj_to_text(self._children)
@property
def child_objs(self):
return self._children
@property
def parents(self):
return _obj_to_text(self._parents)
@property
def path(self):
config = _obj_to_raw(self._parents)
config.append(self.raw)
return '\n'.join(config)
@property
def has_children(self):
return len(self._children) > 0
@property
def has_parents(self):
return len(self._parents) > 0
def add_child(self, obj):
if not isinstance(obj, ConfigLine):
raise AssertionError('child must be of type `ConfigLine`')
self._children.append(obj)
def ignore_line(text, tokens=None):
for item in (tokens or DEFAULT_COMMENT_TOKENS):
if text.startswith(item):
return True
for regex in DEFAULT_IGNORE_LINES_RE:
if regex.match(text):
return True
def _obj_to_text(x):
return [o.text for o in x]
def _obj_to_raw(x):
return [o.raw for o in x]
def _obj_to_block(objects, visited=None):
items = list()
for o in objects:
if o not in items:
items.append(o)
for child in o._children:
if child not in items:
items.append(child)
return _obj_to_raw(items)
def dumps(objects, output='block', comments=False):
if output == 'block':
items = _obj_to_block(objects)
elif output == 'commands':
items = _obj_to_text(objects)
else:
raise TypeError('unknown value supplied for keyword output')
if output != 'commands':
if comments:
for index, item in enumerate(items):
nextitem = index + 1
if nextitem < len(items) and not item.startswith(' ') and items[nextitem].startswith(' '):
item = '!\n%s' % item
items[index] = item
items.append('!')
items.append('end')
return '\n'.join(items)
class NetworkConfig(object):
def __init__(self, indent=1, contents=None, ignore_lines=None):
self._indent = indent
self._items = list()
self._config_text = None
if ignore_lines:
for item in ignore_lines:
if not isinstance(item, re._pattern_type):
item = re.compile(item)
DEFAULT_IGNORE_LINES_RE.add(item)
if contents:
self.load(contents)
@property
def items(self):
return self._items
@property
def config_text(self):
return self._config_text
@property
def sha1(self):
sha1 = hashlib.sha1()
sha1.update(to_bytes(str(self), errors='surrogate_or_strict'))
return sha1.digest()
def __getitem__(self, key):
for line in self:
if line.text == key:
return line
raise KeyError(key)
def __iter__(self):
return iter(self._items)
def __str__(self):
return '\n'.join([c.raw for c in self.items])
def __len__(self):
return len(self._items)
def load(self, s):
self._config_text = s
self._items = self.parse(s)
def loadfp(self, fp):
return self.load(open(fp).read())
def parse(self, lines, comment_tokens=None):
toplevel = re.compile(r'\S')
childline = re.compile(r'^\s*(.+)$')
entry_reg = re.compile(r'([{};])')
ancestors = list()
config = list()
curlevel = 0
prevlevel = 0
for linenum, line in enumerate(to_native(lines, errors='surrogate_or_strict').split('\n')):
text = entry_reg.sub('', line).strip()
cfg = ConfigLine(line)
if not text or ignore_line(text, comment_tokens):
continue
# handle top level commands
if toplevel.match(line):
ancestors = [cfg]
prevlevel = curlevel
curlevel = 0
# handle sub level commands
else:
match = childline.match(line)
line_indent = match.start(1)
prevlevel = curlevel
curlevel = int(line_indent / self._indent)
if (curlevel - 1) > prevlevel:
curlevel = prevlevel + 1
parent_level = curlevel - 1
cfg._parents = ancestors[:curlevel]
if curlevel > len(ancestors):
config.append(cfg)
continue
for i in range(curlevel, len(ancestors)):
ancestors.pop()
ancestors.append(cfg)
ancestors[parent_level].add_child(cfg)
config.append(cfg)
return config
def get_object(self, path):
for item in self.items:
if item.text == path[-1]:
if item.parents == path[:-1]:
return item
def get_block(self, path):
if not isinstance(path, list):
raise AssertionError('path argument must be a list object')
obj = self.get_object(path)
if not obj:
raise ValueError('path does not exist in config')
return self._expand_block(obj)
def get_block_config(self, path):
block = self.get_block(path)
return dumps(block, 'block')
def _expand_block(self, configobj, S=None):
if S is None:
S = list()
S.append(configobj)
for child in configobj._children:
if child in S:
continue
self._expand_block(child, S)
return S
def _diff_line(self, other):
updates = list()
for item in self.items:
if item not in other:
updates.append(item)
return updates
def _diff_strict(self, other):
updates = list()
for index, line in enumerate(self.items):
try:
if str(line).strip() != str(other[index]).strip():
updates.append(line)
except (AttributeError, IndexError):
updates.append(line)
return updates
def _diff_exact(self, other):
updates = list()
if len(other) != len(self.items):
updates.extend(self.items)
else:
for ours, theirs in zip(self.items, other):
if ours != theirs:
updates.extend(self.items)
break
return updates
def difference(self, other, match='line', path=None, replace=None):
"""Perform a config diff against the another network config
:param other: instance of NetworkConfig to diff against
:param match: type of diff to perform. valid values are 'line',
'strict', 'exact'
:param path: context in the network config to filter the diff
:param replace: the method used to generate the replacement lines.
valid values are 'block', 'line'
:returns: a string of lines that are different
"""
if path and match != 'line':
try:
other = other.get_block(path)
except ValueError:
other = list()
else:
other = other.items
# generate a list of ConfigLines that aren't in other
meth = getattr(self, '_diff_%s' % match)
updates = meth(other)
if replace == 'block':
parents = list()
for item in updates:
if not item.has_parents:
parents.append(item)
else:
for p in item._parents:
if p not in parents:
parents.append(p)
updates = list()
for item in parents:
updates.extend(self._expand_block(item))
visited = set()
expanded = list()
for item in updates:
for p in item._parents:
if p.line not in visited:
visited.add(p.line)
expanded.append(p)
expanded.append(item)
visited.add(item.line)
return expanded
def add(self, lines, parents=None):
ancestors = list()
offset = 0
obj = None
# global config command
if not parents:
for line in lines:
item = ConfigLine(line)
item.raw = line
if item not in self.items:
self.items.append(item)
else:
for index, p in enumerate(parents):
try:
i = index + 1
obj = self.get_block(parents[:i])[0]
ancestors.append(obj)
except ValueError:
# add parent to config
offset = index * self._indent
obj = ConfigLine(p)
obj.raw = p.rjust(len(p) + offset)
if ancestors:
obj._parents = list(ancestors)
ancestors[-1]._children.append(obj)
self.items.append(obj)
ancestors.append(obj)
# add child objects
for line in lines:
# check if child already exists
for child in ancestors[-1]._children:
if child.text == line:
break
else:
offset = len(parents) * self._indent
item = ConfigLine(line)
item.raw = line.rjust(len(line) + offset)
item._parents = ancestors
ancestors[-1]._children.append(item)
self.items.append(item)
class CustomNetworkConfig(NetworkConfig):
def items_text(self):
return [item.text for item in self.items]
def expand_section(self, configobj, S=None):
if S is None:
S = list()
S.append(configobj)
for child in configobj.child_objs:
if child in S:
continue
self.expand_section(child, S)
return S
def to_block(self, section):
return '\n'.join([item.raw for item in section])
def get_section(self, path):
try:
section = self.get_section_objects(path)
return self.to_block(section)
except ValueError:
return list()
def get_section_objects(self, path):
if not isinstance(path, list):
path = [path]
obj = self.get_object(path)
if not obj:
raise ValueError('path does not exist in config')
return self.expand_section(obj)

View File

@@ -0,0 +1,87 @@
# This code is part of Ansible, but is an independent component.
# This particular file snippet, and this file snippet only, is BSD licensed.
# Modules you write using this snippet, which is embedded dynamically by Ansible
# still belong to the author of the module, and may assign their own license
# to the complete work.
#
# (c) 2017 Red Hat Inc.
#
# Redistribution and use in source and binary forms, with or without modification,
# are permitted provided that the following conditions are met:
#
# * Redistributions of source code must retain the above copyright
# notice, this list of conditions and the following disclaimer.
# * Redistributions in binary form must reproduce the above copyright notice,
# this list of conditions and the following disclaimer in the documentation
# and/or other materials provided with the distribution.
#
# THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" AND
# ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE IMPLIED
# WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE DISCLAIMED.
# IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE FOR ANY DIRECT, INDIRECT,
# INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL DAMAGES (INCLUDING, BUT NOT LIMITED TO,
# PROCUREMENT OF SUBSTITUTE GOODS OR SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS
# INTERRUPTION) HOWEVER CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT
# LIABILITY, OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE
# USE OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.
#
from ansible.module_utils._text import to_text, to_native
from ansible.module_utils.connection import Connection, ConnectionError
try:
from lxml.etree import Element, fromstring
except ImportError:
from xml.etree.ElementTree import Element, fromstring
NS_MAP = {'nc': "urn:ietf:params:xml:ns:netconf:base:1.0"}
def exec_rpc(module, *args, **kwargs):
connection = NetconfConnection(module._socket_path)
return connection.execute_rpc(*args, **kwargs)
class NetconfConnection(Connection):
def __init__(self, socket_path):
super(NetconfConnection, self).__init__(socket_path)
def __rpc__(self, name, *args, **kwargs):
"""Executes the json-rpc and returns the output received
from remote device.
:name: rpc method to be executed over connection plugin that implements jsonrpc 2.0
:args: Ordered list of params passed as arguments to rpc method
:kwargs: Dict of valid key, value pairs passed as arguments to rpc method
For usage refer the respective connection plugin docs.
"""
self.check_rc = kwargs.pop('check_rc', True)
self.ignore_warning = kwargs.pop('ignore_warning', True)
response = self._exec_jsonrpc(name, *args, **kwargs)
if 'error' in response:
rpc_error = response['error'].get('data')
return self.parse_rpc_error(to_native(rpc_error, errors='surrogate_then_replace'))
return fromstring(to_native(response['result'], errors='surrogate_then_replace'))
def parse_rpc_error(self, rpc_error):
if self.check_rc:
error_root = fromstring(rpc_error)
root = Element('root')
root.append(error_root)
error_list = root.findall('.//nc:rpc-error', NS_MAP)
if not error_list:
raise ConnectionError(to_text(rpc_error, errors='surrogate_then_replace'))
warnings = []
for error in error_list:
message = error.find('./nc:error-message', NS_MAP).text
severity = error.find('./nc:error-severity', NS_MAP).text
if severity == 'warning' and self.ignore_warning:
warnings.append(message)
else:
raise ConnectionError(to_text(rpc_error, errors='surrogate_then_replace'))
return warnings

View File

@@ -0,0 +1,203 @@
# This code is part of Ansible, but is an independent component.
# This particular file snippet, and this file snippet only, is BSD licensed.
# Modules you write using this snippet, which is embedded dynamically by Ansible
# still belong to the author of the module, and may assign their own license
# to the complete work.
#
# Copyright (c) 2015 Peter Sprygada, <psprygada@ansible.com>
#
# Redistribution and use in source and binary forms, with or without modification,
# are permitted provided that the following conditions are met:
#
# * Redistributions of source code must retain the above copyright
# notice, this list of conditions and the following disclaimer.
# * Redistributions in binary form must reproduce the above copyright notice,
# this list of conditions and the following disclaimer in the documentation
# and/or other materials provided with the distribution.
#
# THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" AND
# ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE IMPLIED
# WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE DISCLAIMED.
# IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE FOR ANY DIRECT, INDIRECT,
# INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL DAMAGES (INCLUDING, BUT NOT LIMITED TO,
# PROCUREMENT OF SUBSTITUTE GOODS OR SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS
# INTERRUPTION) HOWEVER CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT
# LIABILITY, OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE
# USE OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.
import traceback
from ansible.module_utils.basic import AnsibleModule
from ansible.module_utils.basic import env_fallback
from ansible.module_utils.network.common.parsing import Cli
from ansible.module_utils._text import to_native
from ansible.module_utils.six import iteritems
NET_TRANSPORT_ARGS = dict(
host=dict(required=True),
port=dict(type='int'),
username=dict(fallback=(env_fallback, ['ANSIBLE_NET_USERNAME'])),
password=dict(no_log=True, fallback=(env_fallback, ['ANSIBLE_NET_PASSWORD'])),
ssh_keyfile=dict(fallback=(env_fallback, ['ANSIBLE_NET_SSH_KEYFILE']), type='path'),
authorize=dict(default=False, fallback=(env_fallback, ['ANSIBLE_NET_AUTHORIZE']), type='bool'),
auth_pass=dict(no_log=True, fallback=(env_fallback, ['ANSIBLE_NET_AUTH_PASS'])),
provider=dict(type='dict', no_log=True),
transport=dict(choices=list()),
timeout=dict(default=10, type='int')
)
NET_CONNECTION_ARGS = dict()
NET_CONNECTIONS = dict()
def _transitional_argument_spec():
argument_spec = {}
for key, value in iteritems(NET_TRANSPORT_ARGS):
value['required'] = False
argument_spec[key] = value
return argument_spec
def to_list(val):
if isinstance(val, (list, tuple)):
return list(val)
elif val is not None:
return [val]
else:
return list()
class ModuleStub(object):
def __init__(self, argument_spec, fail_json):
self.params = dict()
for key, value in argument_spec.items():
self.params[key] = value.get('default')
self.fail_json = fail_json
class NetworkError(Exception):
def __init__(self, msg, **kwargs):
super(NetworkError, self).__init__(msg)
self.kwargs = kwargs
class Config(object):
def __init__(self, connection):
self.connection = connection
def __call__(self, commands, **kwargs):
lines = to_list(commands)
return self.connection.configure(lines, **kwargs)
def load_config(self, commands, **kwargs):
commands = to_list(commands)
return self.connection.load_config(commands, **kwargs)
def get_config(self, **kwargs):
return self.connection.get_config(**kwargs)
def save_config(self):
return self.connection.save_config()
class NetworkModule(AnsibleModule):
def __init__(self, *args, **kwargs):
connect_on_load = kwargs.pop('connect_on_load', True)
argument_spec = NET_TRANSPORT_ARGS.copy()
argument_spec['transport']['choices'] = NET_CONNECTIONS.keys()
argument_spec.update(NET_CONNECTION_ARGS.copy())
if kwargs.get('argument_spec'):
argument_spec.update(kwargs['argument_spec'])
kwargs['argument_spec'] = argument_spec
super(NetworkModule, self).__init__(*args, **kwargs)
self.connection = None
self._cli = None
self._config = None
try:
transport = self.params['transport'] or '__default__'
cls = NET_CONNECTIONS[transport]
self.connection = cls()
except KeyError:
self.fail_json(msg='Unknown transport or no default transport specified')
except (TypeError, NetworkError) as exc:
self.fail_json(msg=to_native(exc), exception=traceback.format_exc())
if connect_on_load:
self.connect()
@property
def cli(self):
if not self.connected:
self.connect()
if self._cli:
return self._cli
self._cli = Cli(self.connection)
return self._cli
@property
def config(self):
if not self.connected:
self.connect()
if self._config:
return self._config
self._config = Config(self.connection)
return self._config
@property
def connected(self):
return self.connection._connected
def _load_params(self):
super(NetworkModule, self)._load_params()
provider = self.params.get('provider') or dict()
for key, value in provider.items():
for args in [NET_TRANSPORT_ARGS, NET_CONNECTION_ARGS]:
if key in args:
if self.params.get(key) is None and value is not None:
self.params[key] = value
def connect(self):
try:
if not self.connected:
self.connection.connect(self.params)
if self.params['authorize']:
self.connection.authorize(self.params)
self.log('connected to %s:%s using %s' % (self.params['host'],
self.params['port'], self.params['transport']))
except NetworkError as exc:
self.fail_json(msg=to_native(exc), exception=traceback.format_exc())
def disconnect(self):
try:
if self.connected:
self.connection.disconnect()
self.log('disconnected from %s' % self.params['host'])
except NetworkError as exc:
self.fail_json(msg=to_native(exc), exception=traceback.format_exc())
def register_transport(transport, default=False):
def register(cls):
NET_CONNECTIONS[transport] = cls
if default:
NET_CONNECTIONS['__default__'] = cls
return cls
return register
def add_argument(key, value):
NET_CONNECTION_ARGS[key] = value

View File

@@ -0,0 +1,295 @@
# This code is part of Ansible, but is an independent component.
# This particular file snippet, and this file snippet only, is BSD licensed.
# Modules you write using this snippet, which is embedded dynamically by Ansible
# still belong to the author of the module, and may assign their own license
# to the complete work.
#
# Copyright (c) 2015 Peter Sprygada, <psprygada@ansible.com>
#
# Redistribution and use in source and binary forms, with or without modification,
# are permitted provided that the following conditions are met:
#
# * Redistributions of source code must retain the above copyright
# notice, this list of conditions and the following disclaimer.
# * Redistributions in binary form must reproduce the above copyright notice,
# this list of conditions and the following disclaimer in the documentation
# and/or other materials provided with the distribution.
#
# THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" AND
# ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE IMPLIED
# WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE DISCLAIMED.
# IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE FOR ANY DIRECT, INDIRECT,
# INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL DAMAGES (INCLUDING, BUT NOT LIMITED TO,
# PROCUREMENT OF SUBSTITUTE GOODS OR SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS
# INTERRUPTION) HOWEVER CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT
# LIABILITY, OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE
# USE OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.
import re
import shlex
import time
from ansible.module_utils.parsing.convert_bool import BOOLEANS_TRUE, BOOLEANS_FALSE
from ansible.module_utils.six import string_types, text_type
from ansible.module_utils.six.moves import zip
def to_list(val):
if isinstance(val, (list, tuple)):
return list(val)
elif val is not None:
return [val]
else:
return list()
class FailedConditionsError(Exception):
def __init__(self, msg, failed_conditions):
super(FailedConditionsError, self).__init__(msg)
self.failed_conditions = failed_conditions
class FailedConditionalError(Exception):
def __init__(self, msg, failed_conditional):
super(FailedConditionalError, self).__init__(msg)
self.failed_conditional = failed_conditional
class AddCommandError(Exception):
def __init__(self, msg, command):
super(AddCommandError, self).__init__(msg)
self.command = command
class AddConditionError(Exception):
def __init__(self, msg, condition):
super(AddConditionError, self).__init__(msg)
self.condition = condition
class Cli(object):
def __init__(self, connection):
self.connection = connection
self.default_output = connection.default_output or 'text'
self._commands = list()
@property
def commands(self):
return [str(c) for c in self._commands]
def __call__(self, commands, output=None):
objects = list()
for cmd in to_list(commands):
objects.append(self.to_command(cmd, output))
return self.connection.run_commands(objects)
def to_command(self, command, output=None, prompt=None, response=None, **kwargs):
output = output or self.default_output
if isinstance(command, Command):
return command
if isinstance(prompt, string_types):
prompt = re.compile(re.escape(prompt))
return Command(command, output, prompt=prompt, response=response, **kwargs)
def add_commands(self, commands, output=None, **kwargs):
for cmd in commands:
self._commands.append(self.to_command(cmd, output, **kwargs))
def run_commands(self):
responses = self.connection.run_commands(self._commands)
for resp, cmd in zip(responses, self._commands):
cmd.response = resp
# wipe out the commands list to avoid issues if additional
# commands are executed later
self._commands = list()
return responses
class Command(object):
def __init__(self, command, output=None, prompt=None, response=None,
**kwargs):
self.command = command
self.output = output
self.command_string = command
self.prompt = prompt
self.response = response
self.args = kwargs
def __str__(self):
return self.command_string
class CommandRunner(object):
def __init__(self, module):
self.module = module
self.items = list()
self.conditionals = set()
self.commands = list()
self.retries = 10
self.interval = 1
self.match = 'all'
self._default_output = module.connection.default_output
def add_command(self, command, output=None, prompt=None, response=None,
**kwargs):
if command in [str(c) for c in self.commands]:
raise AddCommandError('duplicated command detected', command=command)
cmd = self.module.cli.to_command(command, output=output, prompt=prompt,
response=response, **kwargs)
self.commands.append(cmd)
def get_command(self, command, output=None):
for cmd in self.commands:
if cmd.command == command:
return cmd.response
raise ValueError("command '%s' not found" % command)
def get_responses(self):
return [cmd.response for cmd in self.commands]
def add_conditional(self, condition):
try:
self.conditionals.add(Conditional(condition))
except AttributeError as exc:
raise AddConditionError(msg=str(exc), condition=condition)
def run(self):
while self.retries > 0:
self.module.cli.add_commands(self.commands)
responses = self.module.cli.run_commands()
for item in list(self.conditionals):
if item(responses):
if self.match == 'any':
return item
self.conditionals.remove(item)
if not self.conditionals:
break
time.sleep(self.interval)
self.retries -= 1
else:
failed_conditions = [item.raw for item in self.conditionals]
errmsg = 'One or more conditional statements have not been satisfied'
raise FailedConditionsError(errmsg, failed_conditions)
class Conditional(object):
"""Used in command modules to evaluate waitfor conditions
"""
OPERATORS = {
'eq': ['eq', '=='],
'neq': ['neq', 'ne', '!='],
'gt': ['gt', '>'],
'ge': ['ge', '>='],
'lt': ['lt', '<'],
'le': ['le', '<='],
'contains': ['contains'],
'matches': ['matches']
}
def __init__(self, conditional, encoding=None):
self.raw = conditional
try:
key, op, val = shlex.split(conditional)
except ValueError:
raise ValueError('failed to parse conditional')
self.key = key
self.func = self._func(op)
self.value = self._cast_value(val)
def __call__(self, data):
value = self.get_value(dict(result=data))
return self.func(value)
def _cast_value(self, value):
if value in BOOLEANS_TRUE:
return True
elif value in BOOLEANS_FALSE:
return False
elif re.match(r'^\d+\.d+$', value):
return float(value)
elif re.match(r'^\d+$', value):
return int(value)
else:
return text_type(value)
def _func(self, oper):
for func, operators in self.OPERATORS.items():
if oper in operators:
return getattr(self, func)
raise AttributeError('unknown operator: %s' % oper)
def get_value(self, result):
try:
return self.get_json(result)
except (IndexError, TypeError, AttributeError):
msg = 'unable to apply conditional to result'
raise FailedConditionalError(msg, self.raw)
def get_json(self, result):
string = re.sub(r"\[[\'|\"]", ".", self.key)
string = re.sub(r"[\'|\"]\]", ".", string)
parts = re.split(r'\.(?=[^\]]*(?:\[|$))', string)
for part in parts:
match = re.findall(r'\[(\S+?)\]', part)
if match:
key = part[:part.find('[')]
result = result[key]
for m in match:
try:
m = int(m)
except ValueError:
m = str(m)
result = result[m]
else:
result = result.get(part)
return result
def number(self, value):
if '.' in str(value):
return float(value)
else:
return int(value)
def eq(self, value):
return value == self.value
def neq(self, value):
return value != self.value
def gt(self, value):
return self.number(value) > self.value
def ge(self, value):
return self.number(value) >= self.value
def lt(self, value):
return self.number(value) < self.value
def le(self, value):
return self.number(value) <= self.value
def contains(self, value):
return str(self.value) in value
def matches(self, value):
match = re.search(self.value, value, re.M)
return match is not None

View File

@@ -0,0 +1,428 @@
# This code is part of Ansible, but is an independent component.
# This particular file snippet, and this file snippet only, is BSD licensed.
# Modules you write using this snippet, which is embedded dynamically by Ansible
# still belong to the author of the module, and may assign their own license
# to the complete work.
#
# (c) 2016 Red Hat Inc.
#
# Redistribution and use in source and binary forms, with or without modification,
# are permitted provided that the following conditions are met:
#
# * Redistributions of source code must retain the above copyright
# notice, this list of conditions and the following disclaimer.
# * Redistributions in binary form must reproduce the above copyright notice,
# this list of conditions and the following disclaimer in the documentation
# and/or other materials provided with the distribution.
#
# THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" AND
# ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE IMPLIED
# WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE DISCLAIMED.
# IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE FOR ANY DIRECT, INDIRECT,
# INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL DAMAGES (INCLUDING, BUT NOT LIMITED TO,
# PROCUREMENT OF SUBSTITUTE GOODS OR SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS
# INTERRUPTION) HOWEVER CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT
# LIABILITY, OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE
# USE OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.
#
import re
import ast
import operator
import socket
from itertools import chain
from ansible.module_utils.six import iteritems, string_types
from ansible.module_utils.basic import AnsibleFallbackNotFound
try:
from jinja2 import Environment, StrictUndefined
from jinja2.exceptions import UndefinedError
HAS_JINJA2 = True
except ImportError:
HAS_JINJA2 = False
OPERATORS = frozenset(['ge', 'gt', 'eq', 'neq', 'lt', 'le'])
ALIASES = frozenset([('min', 'ge'), ('max', 'le'), ('exactly', 'eq'), ('neq', 'ne')])
def to_list(val):
if isinstance(val, (list, tuple, set)):
return list(val)
elif val is not None:
return [val]
else:
return list()
def sort_list(val):
if isinstance(val, list):
return sorted(val)
return val
class Entity(object):
"""Transforms a dict to with an argument spec
This class will take a dict and apply an Ansible argument spec to the
values. The resulting dict will contain all of the keys in the param
with appropriate values set.
Example::
argument_spec = dict(
command=dict(key=True),
display=dict(default='text', choices=['text', 'json']),
validate=dict(type='bool')
)
transform = Entity(module, argument_spec)
value = dict(command='foo')
result = transform(value)
print result
{'command': 'foo', 'display': 'text', 'validate': None}
Supported argument spec:
* key - specifies how to map a single value to a dict
* read_from - read and apply the argument_spec from the module
* required - a value is required
* type - type of value (uses AnsibleModule type checker)
* fallback - implements fallback function
* choices - set of valid options
* default - default value
"""
def __init__(self, module, attrs=None, args=None, keys=None, from_argspec=False):
args = [] if args is None else args
self._attributes = attrs or {}
self._module = module
for arg in args:
self._attributes[arg] = dict()
if from_argspec:
self._attributes[arg]['read_from'] = arg
if keys and arg in keys:
self._attributes[arg]['key'] = True
self.attr_names = frozenset(self._attributes.keys())
_has_key = False
for name, attr in iteritems(self._attributes):
if attr.get('read_from'):
if attr['read_from'] not in self._module.argument_spec:
module.fail_json(msg='argument %s does not exist' % attr['read_from'])
spec = self._module.argument_spec.get(attr['read_from'])
for key, value in iteritems(spec):
if key not in attr:
attr[key] = value
if attr.get('key'):
if _has_key:
module.fail_json(msg='only one key value can be specified')
_has_key = True
attr['required'] = True
def serialize(self):
return self._attributes
def to_dict(self, value):
obj = {}
for name, attr in iteritems(self._attributes):
if attr.get('key'):
obj[name] = value
else:
obj[name] = attr.get('default')
return obj
def __call__(self, value, strict=True):
if not isinstance(value, dict):
value = self.to_dict(value)
if strict:
unknown = set(value).difference(self.attr_names)
if unknown:
self._module.fail_json(msg='invalid keys: %s' % ','.join(unknown))
for name, attr in iteritems(self._attributes):
if value.get(name) is None:
value[name] = attr.get('default')
if attr.get('fallback') and not value.get(name):
fallback = attr.get('fallback', (None,))
fallback_strategy = fallback[0]
fallback_args = []
fallback_kwargs = {}
if fallback_strategy is not None:
for item in fallback[1:]:
if isinstance(item, dict):
fallback_kwargs = item
else:
fallback_args = item
try:
value[name] = fallback_strategy(*fallback_args, **fallback_kwargs)
except AnsibleFallbackNotFound:
continue
if attr.get('required') and value.get(name) is None:
self._module.fail_json(msg='missing required attribute %s' % name)
if 'choices' in attr:
if value[name] not in attr['choices']:
self._module.fail_json(msg='%s must be one of %s, got %s' % (name, ', '.join(attr['choices']), value[name]))
if value[name] is not None:
value_type = attr.get('type', 'str')
type_checker = self._module._CHECK_ARGUMENT_TYPES_DISPATCHER[value_type]
type_checker(value[name])
elif value.get(name):
value[name] = self._module.params[name]
return value
class EntityCollection(Entity):
"""Extends ```Entity``` to handle a list of dicts """
def __call__(self, iterable, strict=True):
if iterable is None:
iterable = [super(EntityCollection, self).__call__(self._module.params, strict)]
if not isinstance(iterable, (list, tuple)):
self._module.fail_json(msg='value must be an iterable')
return [(super(EntityCollection, self).__call__(i, strict)) for i in iterable]
# these two are for backwards compatibility and can be removed once all of the
# modules that use them are updated
class ComplexDict(Entity):
def __init__(self, attrs, module, *args, **kwargs):
super(ComplexDict, self).__init__(module, attrs, *args, **kwargs)
class ComplexList(EntityCollection):
def __init__(self, attrs, module, *args, **kwargs):
super(ComplexList, self).__init__(module, attrs, *args, **kwargs)
def dict_diff(base, comparable):
""" Generate a dict object of differences
This function will compare two dict objects and return the difference
between them as a dict object. For scalar values, the key will reflect
the updated value. If the key does not exist in `comparable`, then then no
key will be returned. For lists, the value in comparable will wholly replace
the value in base for the key. For dicts, the returned value will only
return keys that are different.
:param base: dict object to base the diff on
:param comparable: dict object to compare against base
:returns: new dict object with differences
"""
if not isinstance(base, dict):
raise AssertionError("`base` must be of type <dict>")
if not isinstance(comparable, dict):
raise AssertionError("`comparable` must be of type <dict>")
updates = dict()
for key, value in iteritems(base):
if isinstance(value, dict):
item = comparable.get(key)
if item is not None:
updates[key] = dict_diff(value, comparable[key])
else:
comparable_value = comparable.get(key)
if comparable_value is not None:
if sort_list(base[key]) != sort_list(comparable_value):
updates[key] = comparable_value
for key in set(comparable.keys()).difference(base.keys()):
updates[key] = comparable.get(key)
return updates
def dict_merge(base, other):
""" Return a new dict object that combines base and other
This will create a new dict object that is a combination of the key/value
pairs from base and other. When both keys exist, the value will be
selected from other. If the value is a list object, the two lists will
be combined and duplicate entries removed.
:param base: dict object to serve as base
:param other: dict object to combine with base
:returns: new combined dict object
"""
if not isinstance(base, dict):
raise AssertionError("`base` must be of type <dict>")
if not isinstance(other, dict):
raise AssertionError("`other` must be of type <dict>")
combined = dict()
for key, value in iteritems(base):
if isinstance(value, dict):
if key in other:
item = other.get(key)
if item is not None:
combined[key] = dict_merge(value, other[key])
else:
combined[key] = item
else:
combined[key] = value
elif isinstance(value, list):
if key in other:
item = other.get(key)
if item is not None:
combined[key] = list(set(chain(value, item)))
else:
combined[key] = item
else:
combined[key] = value
else:
if key in other:
other_value = other.get(key)
if other_value is not None:
if sort_list(base[key]) != sort_list(other_value):
combined[key] = other_value
else:
combined[key] = value
else:
combined[key] = other_value
else:
combined[key] = value
for key in set(other.keys()).difference(base.keys()):
combined[key] = other.get(key)
return combined
def conditional(expr, val, cast=None):
match = re.match(r'^(.+)\((.+)\)$', str(expr), re.I)
if match:
op, arg = match.groups()
else:
op = 'eq'
if ' ' in str(expr):
raise AssertionError('invalid expression: cannot contain spaces')
arg = expr
if cast is None and val is not None:
arg = type(val)(arg)
elif callable(cast):
arg = cast(arg)
val = cast(val)
op = next((oper for alias, oper in ALIASES if op == alias), op)
if not hasattr(operator, op) and op not in OPERATORS:
raise ValueError('unknown operator: %s' % op)
func = getattr(operator, op)
return func(val, arg)
def ternary(value, true_val, false_val):
''' value ? true_val : false_val '''
if value:
return true_val
else:
return false_val
def remove_default_spec(spec):
for item in spec:
if 'default' in spec[item]:
del spec[item]['default']
def validate_ip_address(address):
try:
socket.inet_aton(address)
except socket.error:
return False
return address.count('.') == 3
def validate_prefix(prefix):
if prefix and not 0 <= int(prefix) <= 32:
return False
return True
def load_provider(spec, args):
provider = args.get('provider', {})
for key, value in iteritems(spec):
if key not in provider:
if key in args:
provider[key] = args[key]
elif 'fallback' in value:
provider[key] = _fallback(value['fallback'])
elif 'default' in value:
provider[key] = value['default']
else:
provider[key] = None
args['provider'] = provider
return provider
def _fallback(fallback):
strategy = fallback[0]
args = []
kwargs = {}
for item in fallback[1:]:
if isinstance(item, dict):
kwargs = item
else:
args = item
try:
return strategy(*args, **kwargs)
except AnsibleFallbackNotFound:
pass
class Template:
def __init__(self):
if not HAS_JINJA2:
raise ImportError("jinja2 is required but does not appear to be installed. "
"It can be installed using `pip install jinja2`")
self.env = Environment(undefined=StrictUndefined)
self.env.filters.update({'ternary': ternary})
def __call__(self, value, variables=None, fail_on_undefined=True):
variables = variables or {}
if not self.contains_vars(value):
return value
try:
value = self.env.from_string(value).render(variables)
except UndefinedError:
if not fail_on_undefined:
return None
raise
if value:
try:
return ast.literal_eval(value)
except:
return str(value)
else:
return None
def contains_vars(self, data):
if isinstance(data, string_types):
for marker in (self.env.block_start_string, self.env.variable_start_string, self.env.comment_start_string):
if marker in data:
return True
return False