Fireball2 mode working!

This commit is contained in:
James Cammarata
2013-08-11 00:41:18 -05:00
parent acc5d09351
commit 521e14a3ad
6 changed files with 399 additions and 155 deletions

View File

@@ -312,7 +312,7 @@ class PlayBook(object):
conditional=task.only_if, callbacks=self.runner_callbacks,
sudo=task.sudo, sudo_user=task.sudo_user,
transport=task.transport, sudo_pass=task.sudo_pass, is_playbook=True,
check=self.check, diff=self.diff, environment=task.environment, complex_args=task.args,
check=self.check, diff=self.diff, environment=task.environment, complex_args=task.args, accelerate=task.play.accelerate,
error_on_undefined_vars=C.DEFAULT_UNDEFINED_VAR_BEHAVIOR
)

View File

@@ -29,7 +29,7 @@ class Play(object):
__slots__ = [
'hosts', 'name', 'vars', 'vars_prompt', 'vars_files',
'handlers', 'remote_user', 'remote_port',
'handlers', 'remote_user', 'remote_port', 'accelerate',
'sudo', 'sudo_user', 'transport', 'playbook',
'tags', 'gather_facts', 'serial', '_ds', '_handlers', '_tasks',
'basedir', 'any_errors_fatal', 'roles', 'max_fail_pct'
@@ -39,7 +39,7 @@ class Play(object):
# and don't line up 1:1 with how they are stored
VALID_KEYS = [
'hosts', 'name', 'vars', 'vars_prompt', 'vars_files',
'tasks', 'handlers', 'user', 'port', 'include',
'tasks', 'handlers', 'user', 'port', 'include', 'accelerate',
'sudo', 'sudo_user', 'connection', 'tags', 'gather_facts', 'serial',
'any_errors_fatal', 'roles', 'pre_tasks', 'post_tasks', 'max_fail_percentage'
]
@@ -103,6 +103,7 @@ class Play(object):
self.gather_facts = ds.get('gather_facts', None)
self.remote_port = self.remote_port
self.any_errors_fatal = ds.get('any_errors_fatal', False)
self.accelerate = ds.get('accelerate', False)
self.max_fail_pct = int(ds.get('max_fail_percentage', 100))
load_vars = {}

View File

@@ -138,7 +138,8 @@ class Runner(object):
diff=False, # whether to show diffs for template files that change
environment=None, # environment variables (as dict) to use inside the command
complex_args=None, # structured data in addition to module_args, must be a dict
error_on_undefined_vars=C.DEFAULT_UNDEFINED_VAR_BEHAVIOR # ex. False
error_on_undefined_vars=C.DEFAULT_UNDEFINED_VAR_BEHAVIOR, # ex. False
accelerate=False, # use fireball acceleration
):
# used to lock multiprocess inputs and outputs at various levels
@@ -179,11 +180,16 @@ class Runner(object):
self.environment = environment
self.complex_args = complex_args
self.error_on_undefined_vars = error_on_undefined_vars
self.accelerate = accelerate
self.callbacks.runner = self
# if the transport is 'smart' see if SSH can support ControlPersist if not use paramiko
# 'smart' is the default since 1.2.1/1.3
if self.transport == 'smart':
if self.accelerate:
# if we're using accelerated mode, force the local
# transport to fireball2
self.transport = "fireball2"
elif self.transport == 'smart':
# if the transport is 'smart' see if SSH can support ControlPersist if not use paramiko
# 'smart' is the default since 1.2.1/1.3
cmd = subprocess.Popen(['ssh','-o','ControlPersist'], stdout=subprocess.PIPE, stderr=subprocess.PIPE)
(out, err) = cmd.communicate()
if "Bad configuration option" in err:

View File

@@ -19,7 +19,9 @@ import json
import os
import base64
import socket
import struct
from ansible.callbacks import vvv
from ansible.runner.connection_plugins.ssh import Connection as SSHConnection
from ansible import utils
from ansible import errors
from ansible import constants
@@ -27,32 +29,68 @@ from ansible import constants
class Connection(object):
''' raw socket accelerated connection '''
def __init__(self, runner, host, port, *args, **kwargs):
def __init__(self, runner, host, port, user, password, private_key_file, *args, **kwargs):
self.ssh = SSHConnection(
runner=runner,
host=host,
port=port,
user=user,
password=password,
private_key_file=private_key_file
)
self.runner = runner
self.host = host
self.context = None
self.conn = None
self.key = utils.key_for_hostname(host)
self.fbport = constants.FIREBALL2_PORT
self.is_connected = False
# attempt to work around shared-memory funness
if getattr(self.runner, 'aes_keys', None):
utils.AES_KEYS = self.runner.aes_keys
self.host = host
self.context = None
self.conn = None
self.cipher = AES256Cipher()
def _execute_fb_module(self):
args = "password=%s" % base64.b64encode(self.key.__str__())
self.ssh.connect()
return self.runner._execute_module(self.ssh, "/root/.ansible/tmp", 'fireball2', args, inject={"password":self.key})
if port is None:
self.port = constants.FIREBALL2_PORT
else:
self.port = port
def connect(self):
def connect(self, allow_ssh=True):
''' activates the connection object '''
self.conn = socket.socket()
self.conn.connect((self.host,self.port))
if self.is_connected:
return self
try:
self.conn = socket.socket()
self.conn.connect((self.host,self.fbport))
except:
if allow_ssh:
print "Falling back to ssh to startup accelerated mode"
res = self._execute_fb_module()
return self.connect(allow_ssh=False)
else:
raise errors.AnsibleError("Failed to connect to %s:%s" % (self.host,self.fbport))
self.is_connected = True
return self
def send_data(self, data):
packed_len = struct.pack('Q',len(data))
return self.conn.sendall(packed_len + data)
def recv_data(self):
header_len = 8 # size of a packed unsigned long long
data = b""
while len(data) < header_len:
data += self.conn.recv(1024)
data_len = struct.unpack('Q',data[:header_len])[0]
data = data[header_len:]
while len(data) < data_len:
data += self.conn.recv(1024)
return data
def exec_command(self, cmd, tmp_path, sudo_user, sudoable=False, executable='/bin/sh'):
''' run a command on the remote host '''
@@ -65,12 +103,12 @@ class Connection(object):
executable=executable,
)
data = utils.jsonify(data)
data = self.cipher.encrypt(data)
if self.conn.sendall(data):
data = utils.encrypt(self.key, data)
if self.send_data(data):
raise errors.AnisbleError("Failed to send command to %s:%s" % (self.host,self.port))
response = self.conn.recv(2048)
response = self.cipher.decrypt(response)
response = self.recv_data()
response = utils.decrypt(self.key, response)
response = utils.parse_json(response)
return (response.get('rc',None), '', response.get('stdout',''), response.get('stderr',''))
@@ -83,18 +121,18 @@ class Connection(object):
if not os.path.exists(in_path):
raise errors.AnsibleFileNotFound("file or module does not exist: %s" % in_path)
data = base64.file(in_path).read()
data = file(in_path).read()
data = base64.b64encode(data)
data = dict(mode='put', data=data, out_path=out_path)
# TODO: support chunked file transfer
data = utils.jsonify(data)
data = self.cipher.encrypt(data)
if self.conn.sendall(data):
data = utils.encrypt(self.key, data)
if self.send_data(data):
raise errors.AnsibleError("failed to send the file to %s:%s" % (self.host,self.port))
response = self.conn.recv(2048)
response = self.cipher.decrypt(response)
response = self.recv_data()
response = utils.decrypt(self.key, data)
response = utils.parse_json(response)
# no meaningful response needed for this
@@ -105,12 +143,12 @@ class Connection(object):
data = dict(mode='fetch', in_path=in_path)
data = utils.jsonify(data)
data = self.cipher.encrypt(data)
if self.conn.sendall(data):
data = utils.encrypt(self.key, data)
if self.send_data(data):
raise errors.AnsibleError("failed to initiate the file fetch with %s:%s" % (self.host,self.port))
response = self.socket.recv(2048)
response = self.cipher.decrypt(response)
response = self.recv_data()
response = utils.decrypt(self.key, data)
response = utils.parse_json(response)
response = response['data']
response = base64.b64decode(response)

View File

@@ -31,7 +31,6 @@ import ansible.constants as C
import time
import StringIO
import stat
import string
import termios
import tty
import pipes
@@ -41,11 +40,6 @@ import warnings
import traceback
import getpass
import hmac
from Crypto.Cipher import
from Crypto import Random
from Crypto.Random.random import StrongRandom
VERBOSITY=0
MAX_FILE_SIZE_FOR_DIFF=1*1024*1024
@@ -57,10 +51,8 @@ except ImportError:
try:
from hashlib import md5 as _md5
from hashlib import sha1 as _sha1
except ImportError:
from md5 import md5 as _md5
from sha1 import sha1 as _sha1
PASSLIB_AVAILABLE = False
try:
@@ -69,128 +61,51 @@ try:
except:
pass
KEYCZAR_AVAILABLE=False
try:
import keyczar.errors as key_errors
from keyczar.keys import AesKey
KEYCZAR_AVAILABLE=True
except ImportError:
pass
###############################################################
# Abstractions around PyCrypto
# Abstractions around keyczar
###############################################################
class AES256Cipher(object):
"""
Class abstraction of an AES 256 cipher. This class
also keeps track of the time since the key was last
generated, so you know when to rekey. Rekeying would
be done as follows:
def key_for_hostname(hostname):
# fireball mode is an implementation of ansible firing up zeromq via SSH
# to use no persistent daemons or key management
k = AES256Cipher.gen_key()
<exchange new key with client securely>
AES26Cipher.set_key(k)
if not KEYCZAR_AVAILABLE:
raise errors.AnsibleError("python-keyczar must be installed to use fireball mode")
From this point on the new key would be used until
the lifetime is exceeded.
"""
def __init__(self, lifetime=60*30, mode=AES.MODE_CFB):
self.lifetime = lifetime
self.mode = mode
self.set_key(self.gen_key())
key_path = os.path.expanduser("~/.fireball.keys")
if not os.path.exists(key_path):
os.makedirs(key_path)
key_path = os.path.expanduser("~/.fireball.keys/%s" % hostname)
def gen_key(self):
"""
Generates a 256-bit (32 byte) key to be used for the
AES block encryption.
"""
return b"".join(StrongRandom().sample(string.letters+string.digits+string.punctuation,32))
# use new AES keys every 2 hours, which means fireball must not allow running for longer either
if not os.path.exists(key_path) or (time.time() - os.path.getmtime(key_path) > 60*60*2):
key = AesKey.Generate()
fh = open(key_path, "w")
fh.write(str(key))
fh.close()
return key
else:
fh = open(key_path)
key = AesKey.Read(fh.read())
fh.close()
return key
def set_key(self,key):
"""
Sets the internal key to the one provided and resets the
internal time to now. This key should ONLY be set to one
generated by gen_key()
"""
self.init_time = time.time()
self.key = key
def encrypt(key, msg):
return key.Encrypt(msg)
def should_rekey(self):
"""
Returns true if the lifetime of the current key has
exceeded the set lifetime.
"""
if (time.time() - self.init_time) > self.lifetime:
return True
else:
return False
def _pad(self, msg):
"""
Adds padding to the message so that it is a full
AES block size. Used during encryption of the message.
"""
pad = AES.block_size - len(msg) % AES.block_size
return msg + pad * chr(pad)
def _unpad(self, msg):
"""
Strips out the padding that _pad added. Used during
the decryption of the message.
"""
pad = ord(msg[-1])
return msg[:-pad]
def gen_sig(self, msg):
"""
Generates an HMAC-SHA1 signature for the message
"""
return hmac.new(self.key, msg, _sha1).digest()
def validate_sig(self, msg, sig):
"""
Verifies the generated signature of the message matches
the signature provided.
"""
new_sig = self.gen_sig(msg)
return (new_sig == sig)
def encrypt(self, msg):
"""
Encrypt the message using AES. The signature
is appended to the end of the message and is
used to verify the integrity of the IV and data.
Returns a base64-encoded version of the following:
rval[0:16] = initialization vector
rval[16:-20] = cipher text
rval[-20:] = signature
"""
msg = self._pad(msg)
iv = Random.new().read(AES.block_size)
cipher = AES.new(self.key, self.mode, iv)
data = iv + cipher.encrypt(msg)
sig = self.gen_sig(data)
return (data + sig).encode('base64')
def decrypt(self, msg):
"""
Decrypt the message using AES. The signature is
used to verify the IV and data before decoding to
ensure the integrity of the message. This is an
HMAC-SHA1 hash, so it is always 20 characters
The incoming message format (after base64 decoding)
is as follows:
msg[0:16] = initialization vector
msg[16:-20] = cipher text
msg[-20:] = signature (HMAC-SHA1)
Returns the plain-text of the cipher.
"""
msg = msg.decode('base64')
data = msg[0:-20] # iv + cipher text
msig = msg[-20:] # hmac-sha1 hash
if not self.validate_sig(data,msig):
raise Exception("Failed to validate the message signature")
iv = msg[:AES.block_size]
cipher = AES.new(self.key, self.mode, iv)
return self._unpad(cipher.decrypt(msg)[AES.block_size:])
def decrypt(key, msg):
try:
return key.Decrypt(msg)
except key_errors.InvalidSignatureError:
raise errors.AnsibleError("decryption failed")
###############################################################
# UTILITY FUNCTIONS FOR COMMAND LINE TOOLS