Creating playbook executor and dependent classes

This commit is contained in:
James Cammarata
2014-11-14 16:14:08 -06:00
parent b6c3670f8a
commit 62d79568be
158 changed files with 22486 additions and 2353 deletions

View File

@@ -0,0 +1,167 @@
# (c) 2012-2014, Michael DeHaan <michael.dehaan@gmail.com>
#
# This file is part of Ansible
#
# Ansible is free software: you can redistribute it and/or modify
# it under the terms of the GNU General Public License as published by
# the Free Software Foundation, either version 3 of the License, or
# (at your option) any later version.
#
# Ansible is distributed in the hope that it will be useful,
# but WITHOUT ANY WARRANTY; without even the implied warranty of
# MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the
# GNU General Public License for more details.
#
# You should have received a copy of the GNU General Public License
# along with Ansible. If not, see <http://www.gnu.org/licenses/>.
# Make coding more python3-ish
from __future__ import (absolute_import, division, print_function)
__metaclass__ = type
import pipes
import random
from ansible import constants as C
__all__ = ['ConnectionInformation']
class ConnectionInformation:
'''
This class is used to consolidate the connection information for
hosts in a play and child tasks, where the task may override some
connection/authentication information.
'''
def __init__(self, play=None, options=None):
# FIXME: implement the new methodology here for supporting
# various different auth escalation methods (becomes, etc.)
self.connection = C.DEFAULT_TRANSPORT
self.remote_user = 'root'
self.password = ''
self.port = 22
self.su = False
self.su_user = ''
self.su_pass = ''
self.sudo = False
self.sudo_user = ''
self.sudo_pass = ''
self.verbosity = 0
self.only_tags = set()
self.skip_tags = set()
if play:
self.set_play(play)
if options:
self.set_options(options)
def set_play(self, play):
'''
Configures this connection information instance with data from
the play class.
'''
if play.connection:
self.connection = play.connection
self.remote_user = play.remote_user
self.password = ''
self.port = int(play.port) if play.port else 22
self.su = play.su
self.su_user = play.su_user
self.su_pass = play.su_pass
self.sudo = play.sudo
self.sudo_user = play.sudo_user
self.sudo_pass = play.sudo_pass
def set_options(self, options):
'''
Configures this connection information instance with data from
options specified by the user on the command line. These have a
higher precedence than those set on the play or host.
'''
# FIXME: set other values from options here?
self.verbosity = options.verbosity
if options.connection:
self.connection = options.connection
# get the tag info from options, converting a comma-separated list
# of values into a proper list if need be
if isinstance(options.tags, list):
self.only_tags.update(options.tags)
elif isinstance(options.tags, basestring):
self.only_tags.update(options.tags.split(','))
if isinstance(options.skip_tags, list):
self.skip_tags.update(options.skip_tags)
elif isinstance(options.skip_tags, basestring):
self.skip_tags.update(options.skip_tags.split(','))
def copy(self, ci):
'''
Copies the connection info from another connection info object, used
when merging in data from task overrides.
'''
self.connection = ci.connection
self.remote_user = ci.remote_user
self.password = ci.password
self.port = ci.port
self.su = ci.su
self.su_user = ci.su_user
self.su_pass = ci.su_pass
self.sudo = ci.sudo
self.sudo_user = ci.sudo_user
self.sudo_pass = ci.sudo_pass
self.verbosity = ci.verbosity
self.only_tags = ci.only_tags.copy()
self.skip_tags = ci.skip_tags.copy()
def set_task_override(self, task):
'''
Sets attributes from the task if they are set, which will override
those from the play.
'''
new_info = ConnectionInformation()
new_info.copy(self)
for attr in ('connection', 'remote_user', 'su', 'su_user', 'su_pass', 'sudo', 'sudo_user', 'sudo_pass'):
if hasattr(task, attr):
attr_val = getattr(task, attr)
if attr_val:
setattr(new_info, attr, attr_val)
return new_info
def make_sudo_cmd(self, sudo_exe, executable, cmd):
"""
Helper function for wrapping commands with sudo.
Rather than detect if sudo wants a password this time, -k makes
sudo always ask for a password if one is required. Passing a quoted
compound command to sudo (or sudo -s) directly doesn't work, so we
shellquote it with pipes.quote() and pass the quoted string to the
user's shell. We loop reading output until we see the randomly-
generated sudo prompt set with the -p option.
"""
randbits = ''.join(chr(random.randint(ord('a'), ord('z'))) for x in xrange(32))
prompt = '[sudo via ansible, key=%s] password: ' % randbits
success_key = 'SUDO-SUCCESS-%s' % randbits
sudocmd = '%s -k && %s %s -S -p "%s" -u %s %s -c %s' % (
sudo_exe, sudo_exe, C.DEFAULT_SUDO_FLAGS, prompt,
self.sudo_user, executable or '$SHELL',
pipes.quote('echo %s; %s' % (success_key, cmd))
)
#return ('/bin/sh -c ' + pipes.quote(sudocmd), prompt, success_key)
return (sudocmd, prompt, success_key)

View File

@@ -0,0 +1,66 @@
# (c) 2012-2014, Michael DeHaan <michael.dehaan@gmail.com>
#
# This file is part of Ansible
#
# Ansible is free software: you can redistribute it and/or modify
# it under the terms of the GNU General Public License as published by
# the Free Software Foundation, either version 3 of the License, or
# (at your option) any later version.
#
# Ansible is distributed in the hope that it will be useful,
# but WITHOUT ANY WARRANTY; without even the implied warranty of
# MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the
# GNU General Public License for more details.
#
# You should have received a copy of the GNU General Public License
# along with Ansible. If not, see <http://www.gnu.org/licenses/>.
# Make coding more python3-ish
from __future__ import (absolute_import, division, print_function)
__metaclass__ = type
from multiprocessing.managers import SyncManager, BaseProxy
from ansible.playbook.handler import Handler
from ansible.playbook.task import Task
from ansible.playbook.play import Play
from ansible.errors import AnsibleError
__all__ = ['AnsibleManager']
class VariableManagerWrapper:
'''
This class simply acts as a wrapper around the VariableManager class,
since manager proxies expect a new object to be returned rather than
any existing one. Using this wrapper, a shared proxy can be created
and an existing VariableManager class assigned to it, which can then
be accessed through the exposed proxy methods.
'''
def __init__(self):
self._vm = None
def get_vars(self, loader, play=None, host=None, task=None):
return self._vm.get_vars(loader=loader, play=play, host=host, task=task)
def set_variable_manager(self, vm):
self._vm = vm
def set_host_variable(self, host, varname, value):
self._vm.set_host_variable(host, varname, value)
def set_host_facts(self, host, facts):
self._vm.set_host_facts(host, facts)
class AnsibleManager(SyncManager):
'''
This is our custom manager class, which exists only so we may register
the new proxy below
'''
pass
AnsibleManager.register(
typeid='VariableManagerWrapper',
callable=VariableManagerWrapper,
)

View File

@@ -0,0 +1,185 @@
# (c) 2013-2014, Michael DeHaan <michael.dehaan@gmail.com>
#
# This file is part of Ansible
#
# Ansible is free software: you can redistribute it and/or modify
# it under the terms of the GNU General Public License as published by
# the Free Software Foundation, either version 3 of the License, or
# (at your option) any later version.
#
# Ansible is distributed in the hope that it will be useful,
# but WITHOUT ANY WARRANTY; without even the implied warranty of
# MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the
# GNU General Public License for more details.
#
# You should have received a copy of the GNU General Public License
# along with Ansible. If not, see <http://www.gnu.org/licenses/>.
# from python and deps
from cStringIO import StringIO
import inspect
import json
import os
import shlex
# from Ansible
from ansible import __version__
from ansible import constants as C
from ansible.errors import AnsibleError
from ansible.parsing.utils.jsonify import jsonify
REPLACER = "#<<INCLUDE_ANSIBLE_MODULE_COMMON>>"
REPLACER_ARGS = "\"<<INCLUDE_ANSIBLE_MODULE_ARGS>>\""
REPLACER_COMPLEX = "\"<<INCLUDE_ANSIBLE_MODULE_COMPLEX_ARGS>>\""
REPLACER_WINDOWS = "# POWERSHELL_COMMON"
REPLACER_VERSION = "\"<<ANSIBLE_VERSION>>\""
class ModuleReplacer(object):
"""
The Replacer is used to insert chunks of code into modules before
transfer. Rather than doing classical python imports, this allows for more
efficient transfer in a no-bootstrapping scenario by not moving extra files
over the wire, and also takes care of embedding arguments in the transferred
modules.
This version is done in such a way that local imports can still be
used in the module code, so IDEs don't have to be aware of what is going on.
Example:
from ansible.module_utils.basic import *
... will result in the insertion basic.py into the module
from the module_utils/ directory in the source tree.
All modules are required to import at least basic, though there will also
be other snippets.
# POWERSHELL_COMMON
Also results in the inclusion of the common code in powershell.ps1
"""
# ******************************************************************************
def __init__(self, strip_comments=False):
# FIXME: these members need to be prefixed with '_' and the rest of the file fixed
this_file = inspect.getfile(inspect.currentframe())
# we've moved the module_common relative to the snippets, so fix the path
self.snippet_path = os.path.join(os.path.dirname(this_file), '..', 'module_utils')
self.strip_comments = strip_comments
# ******************************************************************************
def slurp(self, path):
if not os.path.exists(path):
raise AnsibleError("imported module support code does not exist at %s" % path)
fd = open(path)
data = fd.read()
fd.close()
return data
def _find_snippet_imports(self, module_data, module_path):
"""
Given the source of the module, convert it to a Jinja2 template to insert
module code and return whether it's a new or old style module.
"""
module_style = 'old'
if REPLACER in module_data:
module_style = 'new'
elif 'from ansible.module_utils.' in module_data:
module_style = 'new'
elif 'WANT_JSON' in module_data:
module_style = 'non_native_want_json'
output = StringIO()
lines = module_data.split('\n')
snippet_names = []
for line in lines:
if REPLACER in line:
output.write(self.slurp(os.path.join(self.snippet_path, "basic.py")))
snippet_names.append('basic')
if REPLACER_WINDOWS in line:
ps_data = self.slurp(os.path.join(self.snippet_path, "powershell.ps1"))
output.write(ps_data)
snippet_names.append('powershell')
elif line.startswith('from ansible.module_utils.'):
tokens=line.split(".")
import_error = False
if len(tokens) != 3:
import_error = True
if " import *" not in line:
import_error = True
if import_error:
raise AnsibleError("error importing module in %s, expecting format like 'from ansible.module_utils.basic import *'" % module_path)
snippet_name = tokens[2].split()[0]
snippet_names.append(snippet_name)
output.write(self.slurp(os.path.join(self.snippet_path, snippet_name + ".py")))
else:
if self.strip_comments and line.startswith("#") or line == '':
pass
output.write(line)
output.write("\n")
if not module_path.endswith(".ps1"):
# Unixy modules
if len(snippet_names) > 0 and not 'basic' in snippet_names:
raise AnsibleError("missing required import in %s: from ansible.module_utils.basic import *" % module_path)
else:
# Windows modules
if len(snippet_names) > 0 and not 'powershell' in snippet_names:
raise AnsibleError("missing required import in %s: # POWERSHELL_COMMON" % module_path)
return (output.getvalue(), module_style)
# ******************************************************************************
def modify_module(self, module_path, module_args):
with open(module_path) as f:
# read in the module source
module_data = f.read()
(module_data, module_style) = self._find_snippet_imports(module_data, module_path)
#module_args_json = jsonify(module_args)
module_args_json = json.dumps(module_args)
encoded_args = repr(module_args_json.encode('utf-8'))
# these strings should be part of the 'basic' snippet which is required to be included
module_data = module_data.replace(REPLACER_VERSION, repr(__version__))
module_data = module_data.replace(REPLACER_ARGS, "''")
module_data = module_data.replace(REPLACER_COMPLEX, encoded_args)
# FIXME: we're not passing around an inject dictionary anymore, so
# this needs to be fixed with whatever method we use for vars
# like this moving forward
#if module_style == 'new':
# facility = C.DEFAULT_SYSLOG_FACILITY
# if 'ansible_syslog_facility' in inject:
# facility = inject['ansible_syslog_facility']
# module_data = module_data.replace('syslog.LOG_USER', "syslog.%s" % facility)
lines = module_data.split("\n")
shebang = None
if lines[0].startswith("#!"):
shebang = lines[0].strip()
args = shlex.split(str(shebang[2:]))
interpreter = args[0]
interpreter_config = 'ansible_%s_interpreter' % os.path.basename(interpreter)
# FIXME: more inject stuff here...
#if interpreter_config in inject:
# lines[0] = shebang = "#!%s %s" % (inject[interpreter_config], " ".join(args[1:]))
# module_data = "\n".join(lines)
return (module_data, module_style, shebang)

View File

@@ -0,0 +1,258 @@
# (c) 2012-2014, Michael DeHaan <michael.dehaan@gmail.com>
#
# This file is part of Ansible
#
# Ansible is free software: you can redistribute it and/or modify
# it under the terms of the GNU General Public License as published by
# the Free Software Foundation, either version 3 of the License, or
# (at your option) any later version.
#
# Ansible is distributed in the hope that it will be useful,
# but WITHOUT ANY WARRANTY; without even the implied warranty of
# MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the
# GNU General Public License for more details.
#
# You should have received a copy of the GNU General Public License
# along with Ansible. If not, see <http://www.gnu.org/licenses/>.
# Make coding more python3-ish
from __future__ import (absolute_import, division, print_function)
__metaclass__ = type
from ansible.errors import *
from ansible.playbook.task import Task
from ansible.utils.boolean import boolean
__all__ = ['PlayIterator']
# the primary running states for the play iteration
ITERATING_SETUP = 0
ITERATING_TASKS = 1
ITERATING_RESCUE = 2
ITERATING_ALWAYS = 3
ITERATING_COMPLETE = 4
# the failure states for the play iteration
FAILED_NONE = 0
FAILED_SETUP = 1
FAILED_TASKS = 2
FAILED_RESCUE = 3
FAILED_ALWAYS = 4
class PlayState:
'''
A helper class, which keeps track of the task iteration
state for a given playbook. This is used in the PlaybookIterator
class on a per-host basis.
'''
# FIXME: this class is the representation of a finite state machine,
# so we really should have a well defined state representation
# documented somewhere...
def __init__(self, parent_iterator, host):
'''
Create the initial state, which tracks the running state as well
as the failure state, which are used when executing block branches
(rescue/always)
'''
self._run_state = ITERATING_SETUP
self._failed_state = FAILED_NONE
self._task_list = parent_iterator._play.compile()
self._gather_facts = parent_iterator._play.gather_facts
self._host = host
self._cur_block = None
self._cur_role = None
self._cur_task_pos = 0
self._cur_rescue_pos = 0
self._cur_always_pos = 0
self._cur_handler_pos = 0
def next(self, peek=False):
'''
Determines and returns the next available task from the playbook,
advancing through the list of plays as it goes. If peek is set to True,
the internal state is not stored.
'''
task = None
# save this locally so that we can peek at the next task
# without updating the internal state of the iterator
run_state = self._run_state
failed_state = self._failed_state
cur_block = self._cur_block
cur_role = self._cur_role
cur_task_pos = self._cur_task_pos
cur_rescue_pos = self._cur_rescue_pos
cur_always_pos = self._cur_always_pos
cur_handler_pos = self._cur_handler_pos
while True:
if run_state == ITERATING_SETUP:
if failed_state == FAILED_SETUP:
run_state = ITERATING_COMPLETE
else:
run_state = ITERATING_TASKS
if self._gather_facts == 'smart' and not self._host.gathered_facts or boolean(self._gather_facts):
self._host.set_gathered_facts(True)
task = Task()
task.action = 'setup'
break
elif run_state == ITERATING_TASKS:
# if there is any failure state besides FAILED_NONE, we should
# change to some other running state
if failed_state != FAILED_NONE or cur_task_pos > len(self._task_list) - 1:
# if there is a block (and there always should be), start running
# the rescue portion if it exists (and if we haven't failed that
# already), or the always portion (if it exists and we didn't fail
# there too). Otherwise, we're done iterating.
if cur_block:
if failed_state != FAILED_RESCUE and cur_block.rescue:
run_state = ITERATING_RESCUE
cur_rescue_pos = 0
elif failed_state != FAILED_ALWAYS and cur_block.always:
run_state = ITERATING_ALWAYS
cur_always_pos = 0
else:
run_state = ITERATING_COMPLETE
else:
run_state = ITERATING_COMPLETE
else:
task = self._task_list[cur_task_pos]
if cur_block is not None and cur_block != task._block:
run_state = ITERATING_ALWAYS
continue
else:
cur_block = task._block
cur_task_pos += 1
# Break out of the while loop now that we have our task
break
elif run_state == ITERATING_RESCUE:
# If we're iterating through the rescue tasks, make sure we haven't
# failed yet. If so, move on to the always block or if not get the
# next rescue task (if one exists)
if failed_state == FAILED_RESCUE or cur_block.rescue is None or cur_rescue_pos > len(cur_block.rescue) - 1:
run_state = ITERATING_ALWAYS
else:
task = cur_block.rescue[cur_rescue_pos]
cur_rescue_pos += 1
break
elif run_state == ITERATING_ALWAYS:
# If we're iterating through the always tasks, make sure we haven't
# failed yet. If so, we're done iterating otherwise get the next always
# task (if one exists)
if failed_state == FAILED_ALWAYS or cur_block.always is None or cur_always_pos > len(cur_block.always) - 1:
cur_block = None
if failed_state == FAILED_ALWAYS or cur_task_pos > len(self._task_list) - 1:
run_state = ITERATING_COMPLETE
else:
run_state = ITERATING_TASKS
else:
task = cur_block.always[cur_always_pos]
cur_always_pos += 1
break
elif run_state == ITERATING_COMPLETE:
# done iterating, return None to signify that
return None
if task._role:
if cur_role and task._role != cur_role:
cur_role._completed = True
cur_role = task._role
# If we're not just peeking at the next task, save the internal state
if not peek:
self._run_state = run_state
self._failed_state = failed_state
self._cur_block = cur_block
self._cur_role = cur_role
self._cur_task_pos = cur_task_pos
self._cur_rescue_pos = cur_rescue_pos
self._cur_always_pos = cur_always_pos
self._cur_handler_pos = cur_handler_pos
return task
def mark_failed(self):
'''
Escalates the failed state relative to the running state.
'''
if self._run_state == ITERATING_SETUP:
self._failed_state = FAILED_SETUP
elif self._run_state == ITERATING_TASKS:
self._failed_state = FAILED_TASKS
elif self._run_state == ITERATING_RESCUE:
self._failed_state = FAILED_RESCUE
elif self._run_state == ITERATING_ALWAYS:
self._failed_state = FAILED_ALWAYS
class PlayIterator:
'''
The main iterator class, which keeps the state of the playbook
on a per-host basis using the above PlaybookState class.
'''
def __init__(self, inventory, play):
self._play = play
self._inventory = inventory
self._host_entries = dict()
self._first_host = None
# Build the per-host dictionary of playbook states, using a copy
# of the play object so we can post_validate it to ensure any templated
# fields are filled in without modifying the original object, since
# post_validate() saves the templated values.
# FIXME: this is a hacky way of doing this, the iterator should
# instead get the loader and variable manager directly
# as args to __init__
all_vars = inventory._variable_manager.get_vars(loader=inventory._loader, play=play)
new_play = play.copy()
new_play.post_validate(all_vars, ignore_undefined=True)
for host in inventory.get_hosts(new_play.hosts):
if self._first_host is None:
self._first_host = host
self._host_entries[host.get_name()] = PlayState(parent_iterator=self, host=host)
# FIXME: remove, probably not required anymore
#def get_next_task(self, peek=False):
# ''' returns the next task for host[0] '''
#
# first_entry = self._host_entries[self._first_host.get_name()]
# if not peek:
# for entry in self._host_entries:
# if entry != self._first_host.get_name():
# target_entry = self._host_entries[entry]
# if target_entry._cur_task_pos == first_entry._cur_task_pos:
# target_entry.next()
# return first_entry.next(peek=peek)
def get_next_task_for_host(self, host, peek=False):
''' fetch the next task for the given host '''
if host.get_name() not in self._host_entries:
raise AnsibleError("invalid host (%s) specified for playbook iteration" % host)
return self._host_entries[host.get_name()].next(peek=peek)
def mark_host_failed(self, host):
''' mark the given host as failed '''
if host.get_name() not in self._host_entries:
raise AnsibleError("invalid host (%s) specified for playbook iteration" % host)
self._host_entries[host.get_name()].mark_failed()

View File

@@ -19,17 +19,110 @@
from __future__ import (absolute_import, division, print_function)
__metaclass__ = type
import signal
from ansible import constants as C
from ansible.errors import *
from ansible.executor.task_queue_manager import TaskQueueManager
from ansible.playbook import Playbook
from ansible.utils.debug import debug
class PlaybookExecutor:
def __init__(self, list_of_plays=[]):
# self.tqm = TaskQueueManager(forks)
assert False
'''
This is the primary class for executing playbooks, and thus the
basis for bin/ansible-playbook operation.
'''
def run(self):
# for play in list_of_plays:
# for block in play.blocks:
# # block must know its playbook class and context
# tqm.enqueue(block)
# tqm.go()...
assert False
def __init__(self, playbooks, inventory, variable_manager, loader, options):
self._playbooks = playbooks
self._inventory = inventory
self._variable_manager = variable_manager
self._loader = loader
self._options = options
self._tqm = TaskQueueManager(inventory=inventory, callback='default', variable_manager=variable_manager, loader=loader, options=options)
def run(self):
'''
Run the given playbook, based on the settings in the play which
may limit the runs to serialized groups, etc.
'''
signal.signal(signal.SIGINT, self._cleanup)
try:
for playbook_path in self._playbooks:
pb = Playbook.load(playbook_path, variable_manager=self._variable_manager, loader=self._loader)
# FIXME: playbook entries are just plays, so we should rename them
for play in pb.get_entries():
self._inventory.remove_restriction()
# Create a temporary copy of the play here, so we can run post_validate
# on it without the templating changes affecting the original object.
all_vars = self._variable_manager.get_vars(loader=self._loader, play=play)
new_play = play.copy()
new_play.post_validate(all_vars, ignore_undefined=True)
result = True
for batch in self._get_serialized_batches(new_play):
if len(batch) == 0:
raise AnsibleError("No hosts matched the list specified in the play", obj=play._ds)
# restrict the inventory to the hosts in the serialized batch
self._inventory.restrict_to_hosts(batch)
# and run it...
result = self._tqm.run(play=play)
if not result:
break
if not result:
# FIXME: do something here, to signify the playbook execution failed
self._cleanup()
return 1
except:
self._cleanup()
raise
self._cleanup()
return 0
def _cleanup(self, signum=None, framenum=None):
self._tqm.cleanup()
def _get_serialized_batches(self, play):
'''
Returns a list of hosts, subdivided into batches based on
the serial size specified in the play.
'''
# make sure we have a unique list of hosts
all_hosts = self._inventory.get_hosts(play.hosts)
# check to see if the serial number was specified as a percentage,
# and convert it to an integer value based on the number of hosts
if isinstance(play.serial, basestring) and play.serial.endswith('%'):
serial_pct = int(play.serial.replace("%",""))
serial = int((serial_pct/100.0) * len(all_hosts))
else:
serial = int(play.serial)
# if the serial count was not specified or is invalid, default to
# a list of all hosts, otherwise split the list of hosts into chunks
# which are based on the serial size
if serial <= 0:
return [all_hosts]
else:
serialized_batches = []
while len(all_hosts) > 0:
play_hosts = []
for x in range(serial):
if len(all_hosts) > 0:
play_hosts.append(all_hosts.pop(0))
serialized_batches.append(play_hosts)
return serialized_batches

View File

@@ -1,125 +0,0 @@
# (c) 2012-2014, Michael DeHaan <michael.dehaan@gmail.com>
#
# This file is part of Ansible
#
# Ansible is free software: you can redistribute it and/or modify
# it under the terms of the GNU General Public License as published by
# the Free Software Foundation, either version 3 of the License, or
# (at your option) any later version.
#
# Ansible is distributed in the hope that it will be useful,
# but WITHOUT ANY WARRANTY; without even the implied warranty of
# MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the
# GNU General Public License for more details.
#
# You should have received a copy of the GNU General Public License
# along with Ansible. If not, see <http://www.gnu.org/licenses/>.
# Make coding more python3-ish
from __future__ import (absolute_import, division, print_function)
__metaclass__ = type
class PlaybookState:
'''
A helper class, which keeps track of the task iteration
state for a given playbook. This is used in the PlaybookIterator
class on a per-host basis.
'''
def __init__(self, parent_iterator):
self._parent_iterator = parent_iterator
self._cur_play = 0
self._task_list = None
self._cur_task_pos = 0
self._done = False
def next(self, peek=False):
'''
Determines and returns the next available task from the playbook,
advancing through the list of plays as it goes.
'''
task = None
# we save these locally so that we can peek at the next task
# without updating the internal state of the iterator
cur_play = self._cur_play
task_list = self._task_list
cur_task_pos = self._cur_task_pos
while True:
# when we hit the end of the playbook entries list, we set a flag
# and return None to indicate we're there
# FIXME: accessing the entries and parent iterator playbook members
# should be done through accessor functions
if self._done or cur_play > len(self._parent_iterator._playbook._entries) - 1:
self._done = True
return None
# initialize the task list by calling the .compile() method
# on the play, which will call compile() for all child objects
if task_list is None:
task_list = self._parent_iterator._playbook._entries[cur_play].compile()
# if we've hit the end of this plays task list, move on to the next
# and reset the position values for the next iteration
if cur_task_pos > len(task_list) - 1:
cur_play += 1
task_list = None
cur_task_pos = 0
continue
else:
# FIXME: do tag/conditional evaluation here and advance
# the task position if it should be skipped without
# returning a task
task = task_list[cur_task_pos]
cur_task_pos += 1
# Skip the task if it is the member of a role which has already
# been run, unless the role allows multiple executions
if task._role:
# FIXME: this should all be done via member functions
# instead of direct access to internal variables
if task._role.has_run() and not task._role._metadata._allow_duplicates:
continue
# Break out of the while loop now that we have our task
break
# If we're not just peeking at the next task, save the internal state
if not peek:
self._cur_play = cur_play
self._task_list = task_list
self._cur_task_pos = cur_task_pos
return task
class PlaybookIterator:
'''
The main iterator class, which keeps the state of the playbook
on a per-host basis using the above PlaybookState class.
'''
def __init__(self, inventory, log_manager, playbook):
self._playbook = playbook
self._log_manager = log_manager
self._host_entries = dict()
self._first_host = None
# build the per-host dictionary of playbook states
for host in inventory.get_hosts():
if self._first_host is None:
self._first_host = host
self._host_entries[host.get_name()] = PlaybookState(parent_iterator=self)
def get_next_task(self, peek=False):
''' returns the next task for host[0] '''
return self._host_entries[self._first_host.get_name()].next(peek=peek)
def get_next_task_for_host(self, host, peek=False):
''' fetch the next task for the given host '''
if host.get_name() not in self._host_entries:
raise AnsibleError("invalid host specified for playbook iteration")
return self._host_entries[host.get_name()].next(peek=peek)

View File

@@ -0,0 +1,155 @@
# (c) 2012-2014, Michael DeHaan <michael.dehaan@gmail.com>
#
# This file is part of Ansible
#
# Ansible is free software: you can redistribute it and/or modify
# it under the terms of the GNU General Public License as published by
# the Free Software Foundation, either version 3 of the License, or
# (at your option) any later version.
#
# Ansible is distributed in the hope that it will be useful,
# but WITHOUT ANY WARRANTY; without even the implied warranty of
# MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the
# GNU General Public License for more details.
#
# You should have received a copy of the GNU General Public License
# along with Ansible. If not, see <http://www.gnu.org/licenses/>.
# Make coding more python3-ish
from __future__ import (absolute_import, division, print_function)
__metaclass__ = type
import Queue
import multiprocessing
import os
import signal
import sys
import time
import traceback
HAS_ATFORK=True
try:
from Crypto.Random import atfork
except ImportError:
HAS_ATFORK=False
from ansible.executor.task_result import TaskResult
from ansible.playbook.handler import Handler
from ansible.playbook.task import Task
from ansible.utils.debug import debug
__all__ = ['ResultProcess']
class ResultProcess(multiprocessing.Process):
'''
The result worker thread, which reads results from the results
queue and fires off callbacks/etc. as necessary.
'''
def __init__(self, final_q, workers):
# takes a task queue manager as the sole param:
self._final_q = final_q
self._workers = workers
self._cur_worker = 0
self._terminated = False
super(ResultProcess, self).__init__()
def _send_result(self, result):
debug("sending result: %s" % (result,))
self._final_q.put(result, block=False)
debug("done sending result")
def _read_worker_result(self):
result = None
starting_point = self._cur_worker
while True:
(worker_prc, main_q, rslt_q) = self._workers[self._cur_worker]
self._cur_worker += 1
if self._cur_worker >= len(self._workers):
self._cur_worker = 0
try:
if not rslt_q.empty():
debug("worker %d has data to read" % self._cur_worker)
result = rslt_q.get(block=False)
debug("got a result from worker %d: %s" % (self._cur_worker, result))
break
except Queue.Empty:
pass
if self._cur_worker == starting_point:
break
return result
def terminate(self):
self._terminated = True
super(ResultProcess, self).terminate()
def run(self):
'''
The main thread execution, which reads from the results queue
indefinitely and sends callbacks/etc. when results are received.
'''
if HAS_ATFORK:
atfork()
while True:
try:
result = self._read_worker_result()
if result is None:
time.sleep(0.1)
continue
host_name = result._host.get_name()
# send callbacks, execute other options based on the result status
if result.is_failed():
#self._callback.runner_on_failed(result._task, result)
self._send_result(('host_task_failed', result))
elif result.is_unreachable():
#self._callback.runner_on_unreachable(result._task, result)
self._send_result(('host_unreachable', result))
elif result.is_skipped():
#self._callback.runner_on_skipped(result._task, result)
self._send_result(('host_task_skipped', result))
else:
#self._callback.runner_on_ok(result._task, result)
self._send_result(('host_task_ok', result))
# if this task is notifying a handler, do it now
if result._task.notify:
# The shared dictionary for notified handlers is a proxy, which
# does not detect when sub-objects within the proxy are modified.
# So, per the docs, we reassign the list so the proxy picks up and
# notifies all other threads
for notify in result._task.notify:
self._send_result(('notify_handler', notify, result._host))
# if this task is registering facts, do that now
if 'ansible_facts' in result._result:
if result._task.action in ('set_fact', 'include_vars'):
for (key, value) in result._result['ansible_facts'].iteritems():
self._send_result(('set_host_var', result._host, key, value))
else:
self._send_result(('set_host_facts', result._host, result._result['ansible_facts']))
# if this task is registering a result, do it now
if result._task.register:
self._send_result(('set_host_var', result._host, result._task.register, result._result))
except Queue.Empty:
pass
except (KeyboardInterrupt, IOError, EOFError):
break
except:
# FIXME: we should probably send a proper callback here instead of
# simply dumping a stack trace on the screen
traceback.print_exc()
break

View File

@@ -0,0 +1,141 @@
# (c) 2012-2014, Michael DeHaan <michael.dehaan@gmail.com>
#
# This file is part of Ansible
#
# Ansible is free software: you can redistribute it and/or modify
# it under the terms of the GNU General Public License as published by
# the Free Software Foundation, either version 3 of the License, or
# (at your option) any later version.
#
# Ansible is distributed in the hope that it will be useful,
# but WITHOUT ANY WARRANTY; without even the implied warranty of
# MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the
# GNU General Public License for more details.
#
# You should have received a copy of the GNU General Public License
# along with Ansible. If not, see <http://www.gnu.org/licenses/>.
# Make coding more python3-ish
from __future__ import (absolute_import, division, print_function)
__metaclass__ = type
import Queue
import multiprocessing
import os
import signal
import sys
import time
import traceback
HAS_ATFORK=True
try:
from Crypto.Random import atfork
except ImportError:
HAS_ATFORK=False
from ansible.errors import AnsibleError, AnsibleConnectionFailure
from ansible.executor.task_executor import TaskExecutor
from ansible.executor.task_result import TaskResult
from ansible.playbook.handler import Handler
from ansible.playbook.task import Task
from ansible.utils.debug import debug
__all__ = ['ExecutorProcess']
class WorkerProcess(multiprocessing.Process):
'''
The worker thread class, which uses TaskExecutor to run tasks
read from a job queue and pushes results into a results queue
for reading later.
'''
def __init__(self, tqm, main_q, rslt_q, loader, new_stdin):
# takes a task queue manager as the sole param:
self._main_q = main_q
self._rslt_q = rslt_q
self._loader = loader
# dupe stdin, if we have one
try:
fileno = sys.stdin.fileno()
except ValueError:
fileno = None
self._new_stdin = new_stdin
if not new_stdin and fileno is not None:
try:
self._new_stdin = os.fdopen(os.dup(fileno))
except OSError, e:
# couldn't dupe stdin, most likely because it's
# not a valid file descriptor, so we just rely on
# using the one that was passed in
pass
super(WorkerProcess, self).__init__()
def run(self):
'''
Called when the process is started, and loops indefinitely
until an error is encountered (typically an IOerror from the
queue pipe being disconnected). During the loop, we attempt
to pull tasks off the job queue and run them, pushing the result
onto the results queue. We also remove the host from the blocked
hosts list, to signify that they are ready for their next task.
'''
if HAS_ATFORK:
atfork()
while True:
task = None
try:
if not self._main_q.empty():
debug("there's work to be done!")
(host, task, job_vars, connection_info) = self._main_q.get(block=False)
debug("got a task/handler to work on: %s" % task)
new_connection_info = connection_info.set_task_override(task)
# execute the task and build a TaskResult from the result
debug("running TaskExecutor() for %s/%s" % (host, task))
executor_result = TaskExecutor(host, task, job_vars, new_connection_info, self._loader).run()
debug("done running TaskExecutor() for %s/%s" % (host, task))
task_result = TaskResult(host, task, executor_result)
# put the result on the result queue
debug("sending task result")
self._rslt_q.put(task_result, block=False)
debug("done sending task result")
else:
time.sleep(0.1)
except Queue.Empty:
pass
except (IOError, EOFError, KeyboardInterrupt):
break
except AnsibleConnectionFailure:
try:
if task:
task_result = TaskResult(host, task, dict(unreachable=True))
self._rslt_q.put(task_result, block=False)
except:
# FIXME: most likely an abort, catch those kinds of errors specifically
break
except Exception, e:
debug("WORKER EXCEPTION: %s" % e)
debug("WORKER EXCEPTION: %s" % traceback.format_exc())
try:
if task:
task_result = TaskResult(host, task, dict(failed=True, exception=True, stdout=traceback.format_exc()))
self._rslt_q.put(task_result, block=False)
except:
# FIXME: most likely an abort, catch those kinds of errors specifically
break
debug("WORKER PROCESS EXITING")

View File

@@ -19,14 +19,196 @@
from __future__ import (absolute_import, division, print_function)
__metaclass__ = type
from ansible import constants as C
from ansible.errors import AnsibleError
from ansible.executor.connection_info import ConnectionInformation
from ansible.plugins import lookup_loader, connection_loader, action_loader
from ansible.utils.debug import debug
__all__ = ['TaskExecutor']
import json
import time
class TaskExecutor:
def __init__(self, task, host):
pass
'''
This is the main worker class for the executor pipeline, which
handles loading an action plugin to actually dispatch the task to
a given host. This class roughly corresponds to the old Runner()
class.
'''
def run(self):
# returns TaskResult
pass
def __init__(self, host, task, job_vars, connection_info, loader):
self._host = host
self._task = task
self._job_vars = job_vars
self._connection_info = connection_info
self._loader = loader
def run(self):
'''
The main executor entrypoint, where we determine if the specified
task requires looping and either runs the task with
'''
debug("in run()")
items = self._get_loop_items()
if items:
if len(items) > 0:
item_results = self._run_loop(items)
res = dict(results=item_results)
else:
res = dict(changed=False, skipped=True, skipped_reason='No items in the list', results=[])
else:
debug("calling self._execute()")
res = self._execute()
debug("_execute() done")
debug("dumping result to json")
result = json.dumps(res)
debug("done dumping result, returning")
return result
def _get_loop_items(self):
'''
Loads a lookup plugin to handle the with_* portion of a task (if specified),
and returns the items result.
'''
items = None
if self._task.loop and self._task.loop in lookup_loader:
items = lookup_loader.get(self._task.loop).run(self._task.loop_args)
return items
def _run_loop(self, items):
'''
Runs the task with the loop items specified and collates the result
into an array named 'results' which is inserted into the final result
along with the item for which the loop ran.
'''
results = []
# FIXME: squash items into a flat list here for those modules
# which support it (yum, apt, etc.) but make it smarter
# than it is today?
for item in items:
res = self._execute()
res['item'] = item
results.append(res)
return results
def _execute(self):
'''
The primary workhorse of the executor system, this runs the task
on the specified host (which may be the delegated_to host) and handles
the retry/until and block rescue/always execution
'''
connection = self._get_connection()
handler = self._get_action_handler(connection=connection)
# check to see if this task should be skipped, due to it being a member of a
# role which has already run (and whether that role allows duplicate execution)
if self._task._role and self._task._role.has_run():
# If there is no metadata, the default behavior is to not allow duplicates,
# if there is metadata, check to see if the allow_duplicates flag was set to true
if self._task._role._metadata is None or self._task._role._metadata and not self._task._role._metadata.allow_duplicates:
debug("task belongs to a role which has already run, but does not allow duplicate execution")
return dict(skipped=True, skip_reason='This role has already been run, but does not allow duplicates')
if not self._task.evaluate_conditional(self._job_vars):
debug("when evaulation failed, skipping this task")
return dict(skipped=True, skip_reason='Conditional check failed')
if not self._task.evaluate_tags(self._connection_info.only_tags, self._connection_info.skip_tags):
debug("Tags don't match, skipping this task")
return dict(skipped=True, skip_reason='Skipped due to specified tags')
retries = self._task.retries
if retries <= 0:
retries = 1
delay = self._task.delay
if delay < 0:
delay = 0
debug("starting attempt loop")
result = None
for attempt in range(retries):
if attempt > 0:
# FIXME: this should use the callback mechanism
print("FAILED - RETRYING: %s (%d retries left)" % (self._task, retries-attempt))
result['attempts'] = attempt + 1
debug("running the handler")
result = handler.run(task_vars=self._job_vars)
debug("handler run complete")
if self._task.until:
# TODO: implement until logic (pseudo logic follows...)
# if VariableManager.check_conditional(cond, extra_vars=(dict(result=result))):
# break
pass
elif 'failed' not in result and result.get('rc', 0) == 0:
# if the result is not failed, stop trying
break
if attempt < retries - 1:
time.sleep(delay)
debug("attempt loop complete, returning result")
return result
def _get_connection(self):
'''
Reads the connection property for the host, and returns the
correct connection object from the list of connection plugins
'''
# FIXME: delegate_to calculation should be done here
# FIXME: calculation of connection params/auth stuff should be done here
# FIXME: add all port/connection type munging here (accelerated mode,
# fixing up options for ssh, etc.)? and 'smart' conversion
conn_type = self._connection_info.connection
if conn_type == 'smart':
conn_type = 'ssh'
connection = connection_loader.get(conn_type, self._host, self._connection_info)
if not connection:
raise AnsibleError("the connection plugin '%s' was not found" % conn_type)
connection.connect()
return connection
def _get_action_handler(self, connection):
'''
Returns the correct action plugin to handle the requestion task action
'''
if self._task.action in action_loader:
if self._task.async != 0:
raise AnsibleError("async mode is not supported with the %s module" % module_name)
handler_name = self._task.action
elif self._task.async == 0:
handler_name = 'normal'
else:
handler_name = 'async'
handler = action_loader.get(
handler_name,
task=self._task,
connection=connection,
connection_info=self._connection_info,
loader=self._loader
)
if not handler:
raise AnsibleError("the handler '%s' was not found" % handler_name)
return handler

View File

@@ -19,18 +19,191 @@
from __future__ import (absolute_import, division, print_function)
__metaclass__ = type
class TaskQueueManagerHostPlaybookIterator:
import multiprocessing
import os
import socket
import sys
def __init__(self, host, playbook):
pass
from ansible.errors import AnsibleError
from ansible.executor.connection_info import ConnectionInformation
#from ansible.executor.manager import AnsibleManager
from ansible.executor.play_iterator import PlayIterator
from ansible.executor.process.worker import WorkerProcess
from ansible.executor.process.result import ResultProcess
from ansible.plugins import callback_loader, strategy_loader
def get_next_task(self):
assert False
from ansible.utils.debug import debug
def is_blocked(self):
# depending on strategy, either
# linear -- all prev tasks must be completed for all hosts
# free -- this host doesnt have any more work to do
assert False
__all__ = ['TaskQueueManager']
class TaskQueueManager:
'''
This class handles the multiprocessing requirements of Ansible by
creating a pool of worker forks, a result handler fork, and a
manager object with shared datastructures/queues for coordinating
work between all processes.
The queue manager is responsible for loading the play strategy plugin,
which dispatches the Play's tasks to hosts.
'''
def __init__(self, inventory, callback, variable_manager, loader, options):
self._inventory = inventory
self._variable_manager = variable_manager
self._loader = loader
self._options = options
# a special flag to help us exit cleanly
self._terminated = False
# create and start the multiprocessing manager
#self._manager = AnsibleManager()
#self._manager.start()
# this dictionary is used to keep track of notified handlers
self._notified_handlers = dict()
# dictionaries to keep track of failed/unreachable hosts
self._failed_hosts = dict()
self._unreachable_hosts = dict()
self._final_q = multiprocessing.Queue()
# FIXME: hard-coded the default callback plugin here, which
# should be configurable.
self._callback = callback_loader.get(callback)
# create the pool of worker threads, based on the number of forks specified
try:
fileno = sys.stdin.fileno()
except ValueError:
fileno = None
self._workers = []
for i in range(self._options.forks):
# duplicate stdin, if possible
new_stdin = None
if fileno is not None:
try:
new_stdin = os.fdopen(os.dup(fileno))
except OSError, e:
# couldn't dupe stdin, most likely because it's
# not a valid file descriptor, so we just rely on
# using the one that was passed in
pass
main_q = multiprocessing.Queue()
rslt_q = multiprocessing.Queue()
prc = WorkerProcess(self, main_q, rslt_q, loader, new_stdin)
prc.start()
self._workers.append((prc, main_q, rslt_q))
self._result_prc = ResultProcess(self._final_q, self._workers)
self._result_prc.start()
def _initialize_notified_handlers(self, handlers):
'''
Clears and initializes the shared notified handlers dict with entries
for each handler in the play, which is an empty array that will contain
inventory hostnames for those hosts triggering the handler.
'''
# Zero the dictionary first by removing any entries there.
# Proxied dicts don't support iteritems, so we have to use keys()
for key in self._notified_handlers.keys():
del self._notified_handlers[key]
# FIXME: there is a block compile helper for this...
handler_list = []
for handler_block in handlers:
handler_list.extend(handler_block.compile())
# then initalize it with the handler names from the handler list
for handler in handler_list:
self._notified_handlers[handler.get_name()] = []
def run(self, play):
'''
Iterates over the roles/tasks in a play, using the given (or default)
strategy for queueing tasks. The default is the linear strategy, which
operates like classic Ansible by keeping all hosts in lock-step with
a given task (meaning no hosts move on to the next task until all hosts
are done with the current task).
'''
connection_info = ConnectionInformation(play, self._options)
self._callback.set_connection_info(connection_info)
# run final validation on the play now, to make sure fields are templated
# FIXME: is this even required? Everything is validated and merged at the
# task level, so else in the play needs to be templated
#all_vars = self._vmw.get_vars(loader=self._dlw, play=play)
#all_vars = self._vmw.get_vars(loader=self._loader, play=play)
#play.post_validate(all_vars=all_vars)
self._callback.playbook_on_play_start(play.name)
# initialize the shared dictionary containing the notified handlers
self._initialize_notified_handlers(play.handlers)
# load the specified strategy (or the default linear one)
strategy = strategy_loader.get(play.strategy, self)
if strategy is None:
raise AnsibleError("Invalid play strategy specified: %s" % play.strategy, obj=play._ds)
# build the iterator
iterator = PlayIterator(inventory=self._inventory, play=play)
# and run the play using the strategy
return strategy.run(iterator, connection_info)
def cleanup(self):
debug("RUNNING CLEANUP")
self.terminate()
self._final_q.close()
self._result_prc.terminate()
for (worker_prc, main_q, rslt_q) in self._workers:
rslt_q.close()
main_q.close()
worker_prc.terminate()
def get_inventory(self):
return self._inventory
def get_callback(self):
return self._callback
def get_variable_manager(self):
return self._variable_manager
def get_loader(self):
return self._loader
def get_server_pipe(self):
return self._server_pipe
def get_client_pipe(self):
return self._client_pipe
def get_pending_results(self):
return self._pending_results
def get_allow_processing(self):
return self._allow_processing
def get_notified_handlers(self):
return self._notified_handlers
def get_workers(self):
return self._workers[:]
def terminate(self):
self._terminated = True

View File

@@ -19,3 +19,39 @@
from __future__ import (absolute_import, division, print_function)
__metaclass__ = type
from ansible.parsing import DataLoader
class TaskResult:
'''
This class is responsible for interpretting the resulting data
from an executed task, and provides helper methods for determining
the result of a given task.
'''
def __init__(self, host, task, return_data):
self._host = host
self._task = task
if isinstance(return_data, dict):
self._result = return_data.copy()
else:
self._result = DataLoader().load(return_data)
def is_changed(self):
return self._check_key('changed')
def is_skipped(self):
return self._check_key('skipped')
def is_failed(self):
return self._check_key('failed') or self._result.get('rc', 0) != 0
def is_unreachable(self):
return self._check_key('unreachable')
def _check_key(self, key):
if 'results' in self._result:
flag = False
for res in self._result.get('results', []):
flag |= res.get(key, False)
else:
return self._result.get(key, False)