#!/usr/bin/env python

# (c) 2016, Ansible, Inc. <support@ansible.com>
#
# This file is part of Ansible
#
# Ansible is free software: you can redistribute it and/or modify
# it under the terms of the GNU General Public License as published by
# the Free Software Foundation, either version 3 of the License, or
# (at your option) any later version.
#
# Ansible is distributed in the hope that it will be useful,
# but WITHOUT ANY WARRANTY; without even the implied warranty of
# MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE.  See the
# GNU General Public License for more details.
#
# You should have received a copy of the GNU General Public License
# along with Ansible.  If not, see <http://www.gnu.org/licenses/>.

########################################################
from __future__ import (absolute_import, division, print_function)
__metaclass__ = type

# NOTE(review): __requires__ looks like the setuptools convention that
# pkg_resources reads from the __main__ module when it is imported, so that
# the installed 'ansible' distribution is resolved first -- confirm.
__requires__ = ['ansible']
try:
    # imported only for its import-time side effects; the name itself is
    # not referenced anywhere else in this file as shown
    import pkg_resources
except Exception:
    # best effort: keep going when pkg_resources is missing or fails to load
    pass

import fcntl
import hashlib
import os
import shlex
import signal
import socket
import struct
import sys
import time
import traceback

from io import BytesIO

from ansible import constants as C
from ansible.module_utils._text import to_bytes, to_native
from ansible.module_utils.six.moves import cPickle, StringIO
from ansible.playbook.play_context import PlayContext
from ansible.plugins import connection_loader
from ansible.utils.path import unfrackpath, makedirs_safe

def do_fork():
    '''
    Daemonize using the classic UNIX double fork, following
    http://code.activestate.com/recipes/66012-fork-a-daemon-process-on-unix/

    Three processes are involved:

    * the original caller gets the pid of the first child back (> 0)
    * the first child changes directory to "/", starts a new session,
      resets the umask, forks again and then exits with status 0
    * the grandchild closes the stdin/stdout/stderr file descriptors
      and returns 0

    An OSError raised at any step terminates the current process with
    exit status 1.
    '''
    try:
        first_pid = os.fork()
        if first_pid > 0:
            # original process: report the child's pid to the caller
            return first_pid

        # first child: detach from the environment inherited from the parent
        os.chdir("/")
        os.setsid()
        os.umask(0)

        second_pid = os.fork()
        if second_pid > 0:
            # intermediate process has served its purpose
            sys.exit(0)

        # grandchild: drop the inherited standard streams
        for stream in (sys.stdin, sys.stdout, sys.stderr):
            os.close(stream.fileno())

        return second_pid
    except OSError:
        sys.exit(1)

def send_data(s, data):
    '''
    Write one length-prefixed message to the socket ``s``.

    The frame is the length of ``data`` packed as an 8-byte unsigned
    integer in network byte order ('!Q'), immediately followed by the
    ``data`` bytes themselves. Returns whatever ``s.sendall`` returns.
    '''
    header = struct.pack('!Q', len(data))
    frame = header + data
    return s.sendall(frame)

def _recv_exact(s, nbytes):
    '''
    Read exactly ``nbytes`` bytes from the socket ``s``.

    Returns the bytes read (``b""`` when ``nbytes`` is 0), or None if the
    peer closed the connection before ``nbytes`` bytes had arrived.
    '''
    # collect the pieces and join once; repeatedly doing ``bytes += chunk``
    # copies the whole buffer on every iteration
    chunks = []
    remaining = nbytes
    while remaining > 0:
        chunk = s.recv(remaining)
        if not chunk:
            # recv() returning nothing means the other end hung up
            return None
        chunks.append(chunk)
        remaining -= len(chunk)
    return b"".join(chunks)

def recv_data(s):
    '''
    Read one length-prefixed message from the socket ``s``.

    The wire format is the one produced by send_data(): an 8-byte unsigned
    length in network byte order ('!Q') followed by that many payload bytes.

    Returns the payload as bytes (possibly empty), or None if the connection
    was closed before a complete header and payload could be read.
    '''
    header_len = 8  # size of a packed unsigned long long
    header = _recv_exact(s, header_len)
    if header is None:
        return None
    data_len = struct.unpack('!Q', header)[0]
    return _recv_exact(s, data_len)

class Server():
    '''
    Daemon side of a persistent connection.

    On construction it loads and connects the connection plugin named by
    the play context, then listens on a UNIX stream socket bound to
    ``path``. run() serves requests arriving on that socket until no new
    client shows up within C.PERSISTENT_CONNECT_TIMEOUT seconds.

    Each request is one send_data() frame:

    * ``EXEC: <command>``   - run <command> through the connection
    * ``PUT: <src> <dst>``  - call the connection's put_file(src, dst)
    * ``FETCH: <src> <dst>`` - call the connection's fetch_file(src, dst)

    and is answered with three frames: return code, stdout, stderr.
    '''

    # prefix marking a request to execute a command
    _EXEC_PREFIX = b'EXEC: '

    def __init__(self, path, play_context):
        '''
        :arg path: filesystem path of the UNIX socket to create and listen on
        :arg play_context: PlayContext used to pick and configure the
            connection plugin
        '''
        self.path = path
        self.play_context = play_context

        # FIXME: the connection loader here is created brand new,
        #        so it will not see any custom paths loaded (ie. via
        #        roles), so we will need to serialize the connection
        #        loader and send it over as we do the PlayContext
        #        in the main() method below.
        self.conn = connection_loader.get(play_context.connection, play_context, sys.stdin)
        self.conn._connect()

        self.socket = socket.socket(socket.AF_UNIX, socket.SOCK_STREAM)
        self.socket.bind(path)
        self.socket.listen(1)

        signal.signal(signal.SIGALRM, self.alarm_handler)

    def alarm_handler(self, signum, frame):
        '''
        SIGALRM handler: close the listening socket so that the accept()
        call pending in run() fails and the server shuts down.
        '''
        # FIXME: this should also set internal flags for other
        #        areas of code to check, so they can terminate
        #        earlier than the socket going back to the accept
        #        call and failing there.
        self.socket.close()

    def _handle_request(self, data):
        '''
        Carry out a single request and return ``(rc, stdout, stderr)``.

        ``rc`` is 255 unless the requested operation completed; for PUT and
        FETCH it is 0 on success. Any exception raised while handling the
        request is reported as a traceback in ``stderr`` with ``rc`` 255.
        '''
        rc = 255
        stdout = stderr = ''
        try:
            if data.startswith(self._EXEC_PREFIX):
                # strip only the leading marker; splitting on it would cut
                # off commands that contain the text 'EXEC: ' themselves
                cmd = data[len(self._EXEC_PREFIX):]
                (rc, stdout, stderr) = self.conn.exec_command(cmd)
            elif data.startswith(b'PUT: ') or data.startswith(b'FETCH: '):
                (op, src, dst) = shlex.split(to_native(data))
                if op == 'FETCH:':
                    self.conn.fetch_file(src, dst)
                elif op == 'PUT:':
                    self.conn.put_file(src, dst)
                rc = 0
            else:
                stderr = 'Invalid action specified'
        except Exception:
            # report the failure to the client rather than hiding it
            rc = 255
            stdout = ''
            stderr = traceback.format_exc()
        return (rc, stdout, stderr)

    def run(self):
        '''
        Accept clients one at a time and serve their requests until the
        accept() call fails (normally because alarm_handler closed the
        listening socket after the timeout). On the way out the connection
        is closed and the socket file is removed.
        '''
        try:
            while True:
                # set the alarm, if we don't get an accept before it
                # goes off we exit (via an exception caused by the socket
                # getting closed while waiting on accept())
                # FIXME: is this the best way to exit? as noted above in the
                #        handler we should probably be setting a flag to check
                #        here and in other parts of the code
                signal.alarm(C.PERSISTENT_CONNECT_TIMEOUT)
                try:
                    (s, _addr) = self.socket.accept()
                    # clear the alarm
                    # FIXME: potential race condition here between the accept and
                    #        time to this call.
                    signal.alarm(0)
                except Exception:
                    break

                try:
                    while True:
                        data = recv_data(s)
                        if not data:
                            # client hung up (or sent an empty request)
                            break

                        (rc, stdout, stderr) = self._handle_request(data)

                        send_data(s, to_bytes(str(rc)))
                        send_data(s, to_bytes(stdout))
                        send_data(s, to_bytes(stderr))
                finally:
                    # never leak the client socket, even if the exchange failed
                    s.close()
        except Exception as e:
            # FIXME: proper logging and error handling here
            print("run exception: %s" % e)
            print(traceback.format_exc())
        finally:
            # when done, close the connection properly and cleanup
            # the socket file so it can be recreated
            try:
                self.conn.close()
            except Exception:
                # best effort: the socket file must still be removed below
                pass
            os.remove(self.path)

def main():
    '''
    Client entry point.

    1. Read a pickled PlayContext from stdin, terminated by a line
       containing ``#END_INIT#``.
    2. Derive the socket path from a hash of the connection parameters.
    3. While holding a file lock, fork a daemon running Server if the
       socket file does not exist yet.
    4. Read the first non-blank line from stdin, send it to the daemon and
       relay the returned stdout/stderr.

    Exit status: 1 if the init data cannot be read, 255 if the daemon
    cannot be reached or closes the connection before replying in full,
    0 if stdin ends before any command was given, otherwise the return
    code reported by the daemon.
    '''
    try:
        # read the play context data via stdin, which means depickling it
        # FIXME: as noted above, we will probably need to deserialize the
        #        connection loader here as well at some point, otherwise this
        #        won't find role- or playbook-based connection plugins
        # NOTE: unpickling can execute arbitrary code; this is only safe as
        #       long as stdin is fed by a trusted process.
        init_lines = []
        cur_line = sys.stdin.readline()
        while cur_line.strip() != '#END_INIT#':
            if cur_line == '':
                raise Exception("EOL found before init data was complete")
            init_lines.append(cur_line)
            cur_line = sys.stdin.readline()
        src = BytesIO(to_bytes(''.join(init_lines)))
        pc_data = cPickle.load(src)
        src.close()

        pc = PlayContext()
        pc.deserialize(pc_data)
    except Exception as e:
        # FIXME: better error message/handling/logging
        print("FAIL: %s" % e)
        print(traceback.format_exc())
        sys.exit(1)

    # here we create a hash to use later when creating the socket file,
    # so we can hide the info about the target host/user/etc.
    m = hashlib.sha256()
    for attr in ('connection', 'remote_addr', 'port', 'remote_user'):
        val = getattr(pc, attr, None)
        if val:
            m.update(to_bytes(val))

    # create the persistent connection dir if need be and create the paths
    # which we will be using later
    tmp_path = unfrackpath("$HOME/.ansible/pc")
    makedirs_safe(tmp_path)
    lk_path = unfrackpath("%s/.ansible_pc_lock" % tmp_path)
    sf_path = unfrackpath("%s/conn-%s" % (tmp_path, m.hexdigest()[0:12]))

    # if the socket file doesn't exist, spin up the daemon process
    lock_fd = os.open(lk_path, os.O_RDWR | os.O_CREAT, 0o600)
    fcntl.lockf(lock_fd, fcntl.LOCK_EX)
    if not os.path.exists(sf_path):
        pid = do_fork()
        if pid == 0:
            # daemon process: set up the server, release the lock, then
            # serve until the server shuts itself down
            server = Server(sf_path, pc)
            fcntl.lockf(lock_fd, fcntl.LOCK_UN)
            os.close(lock_fd)
            server.run()
            sys.exit(0)
    fcntl.lockf(lock_fd, fcntl.LOCK_UN)
    os.close(lock_fd)

    # only one command is handled per invocation: the first non-blank
    # line on stdin
    data = sys.stdin.readline()
    while data != '' and data.strip() == '':
        data = sys.stdin.readline()
    if data == '':
        # stdin ended without a command
        sys.exit(0)

    # now connect to the daemon process
    # FIXME: if the socket file existed but the daemonized process was killed,
    #        the connection will timeout here. Need to make this more resilient.
    sf = socket.socket(socket.AF_UNIX, socket.SOCK_STREAM)
    try:
        attempts = 1
        while True:
            try:
                sf.connect(sf_path)
                break
            except socket.error:
                # FIXME: better error handling/logging/message here
                # FIXME: make # of retries configurable?
                time.sleep(0.1)
                attempts += 1
                if attempts > 10:
                    sys.stderr.write("failed to connect to the host, connection timeout\n")
                    sys.exit(255)

        send_data(sf, to_bytes(data.strip()))

        # the reply is three frames: return code, stdout, stderr
        reply = []
        for _unused in range(3):
            part = recv_data(sf)
            if part is None:
                # the daemon went away before the reply was complete
                sys.stderr.write("connection to the persistent connection process was closed before a full response was received\n")
                sys.exit(255)
            reply.append(part)
        (rc_data, stdout, stderr) = reply

        rc = int(rc_data, 10)
        sys.stdout.write(to_native(stdout))
        sys.stderr.write(to_native(stderr))
    finally:
        sf.close()

    sys.exit(rc)

# run main() only when this file is executed as a script, not when imported
if __name__ == '__main__':
    main()
