#!/usr/bin/env python

# (c) 2016, Ansible, Inc. <support@ansible.com>
#
# This file is part of Ansible
#
# Ansible is free software: you can redistribute it and/or modify
# it under the terms of the GNU General Public License as published by
# the Free Software Foundation, either version 3 of the License, or
# (at your option) any later version.
#
# Ansible is distributed in the hope that it will be useful,
# but WITHOUT ANY WARRANTY; without even the implied warranty of
# MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE.  See the
# GNU General Public License for more details.
#
# You should have received a copy of the GNU General Public License
# along with Ansible.  If not, see <http://www.gnu.org/licenses/>.

########################################################
from __future__ import (absolute_import, division, print_function)

# Make every class defined in this file a new-style class under Python 2.
__metaclass__ = type
# NOTE(review): presumably read by setuptools' pkg_resources (imported just
# below) to resolve the 'ansible' distribution -- confirm.
__requires__ = ['ansible']

try:
    import pkg_resources
except Exception:
    # pkg_resources is treated as optional: any failure while importing it
    # is ignored so the script can still start.
    pass

import fcntl
import os
import shlex
import signal
import socket
import struct
import sys
import time
import traceback
import syslog
import datetime
import logging

from ansible import constants as C
from ansible.module_utils._text import to_bytes, to_native
from ansible.module_utils.six import PY3
from ansible.module_utils.six.moves import cPickle
from ansible.playbook.play_context import PlayContext
from ansible.plugins import connection_loader
from ansible.utils.path import unfrackpath, makedirs_safe
from ansible.errors import AnsibleConnectionFailure
from ansible.utils.display import Display


def do_fork():
    '''
    Detach from the caller using the classic UNIX double fork. Based on
    http://code.activestate.com/recipes/66012-fork-a-daemon-process-on-unix/

    Return value depends on which process is returning:

    * the original caller gets the pid of the first child (> 0)
    * the daemonized grandchild gets 0
    * the intermediate child never returns, it exits with status 0

    In the grandchild, stdout and stderr are pointed at C.DEFAULT_LOG_PATH
    (or /dev/null when that setting is the empty string) and stdin is
    closed. An OSError at any step ends the current process with status 1.
    '''
    try:
        first_child = os.fork()
        if first_child > 0:
            # original process: report the child's pid to the caller
            return first_child

        # first child: start a new session and clear the umask
        # (changing directory to "/" is intentionally not done here)
        os.setsid()
        os.umask(0)

        daemon_pid = os.fork()
        if daemon_pid > 0:
            # intermediate child: it only existed to spawn the daemon
            sys.exit(0)

        # daemon: pick where output goes, then rewire the standard streams
        if C.DEFAULT_LOG_PATH != '':
            sink_path = C.DEFAULT_LOG_PATH
        else:
            sink_path = '/dev/null'
        stdout_sink = open(sink_path, 'ab+')
        # stderr is opened unbuffered
        stderr_sink = open(sink_path, 'ab+', 0)

        os.dup2(stdout_sink.fileno(), sys.stdout.fileno())
        os.dup2(stderr_sink.fileno(), sys.stderr.fileno())
        os.close(sys.stdin.fileno())

        return daemon_pid
    except OSError:
        sys.exit(1)

def send_data(s, data):
    '''
    Write one length-prefixed frame to the socket ``s``.

    A frame is an 8 byte network-order unsigned length (struct format
    ``!Q``) immediately followed by the ``data`` bytes. The return value
    is whatever ``s.sendall`` returns.
    '''
    frame = struct.pack('!Q', len(data)) + data
    return s.sendall(frame)

def recv_data(s):
    '''
    Read one length-prefixed frame from the socket ``s``.

    The frame layout matches send_data(): an 8 byte ``!Q`` length header
    followed by that many payload bytes. Returns the payload bytes (which
    may be empty), or None if the peer closes the socket before the header
    or the payload has been received in full.
    '''
    def read_exact(wanted):
        # collect exactly ``wanted`` bytes; None signals an early EOF
        pieces = []
        missing = wanted
        while missing > 0:
            piece = s.recv(missing)
            if not piece:
                return None
            pieces.append(piece)
            missing -= len(piece)
        return b"".join(pieces)

    # the header is the size of a packed unsigned long long
    header = read_exact(8)
    if header is None:
        return None

    (payload_len,) = struct.unpack('!Q', header)
    return read_exact(payload_len)


class Server():
    '''
    Daemon side of a persistent connection.

    On construction the connection plugin named by the play context is
    loaded and connected, and a UNIX domain socket is bound at ``path``.
    run() then services framed requests (``EXEC: ``, ``PUT: ``,
    ``FETCH: `` and ``CONTEXT: ``) until the socket goes idle or fails.
    '''

    def __init__(self, path, play_context):
        '''
        Connect to the remote host and start listening on the control socket.

        :arg path: filesystem path at which the UNIX socket is bound
        :arg play_context: the PlayContext describing the connection
        :raises AnsibleConnectionFailure: when the plugin reports that it is
            not connected after ``_connect()`` was called
        '''

        self.path = path
        self.play_context = play_context

        display.display(
            'creating new control socket for host %s:%s as user %s' %
            (play_context.remote_addr, play_context.port, play_context.remote_user),
            log_only=True
        )

        display.display('control socket path is %s' % path, log_only=True)
        display.display('current working directory is %s' % os.getcwd(), log_only=True)

        self._start_time = datetime.datetime.now()

        display.display("using connection plugin %s" % self.play_context.connection, log_only=True)

        self.conn = connection_loader.get(play_context.connection, play_context, sys.stdin)
        self.conn._connect()
        if not self.conn.connected:
            # the attribute is ``play_context``; the former ``_play_context``
            # lookup raised AttributeError and masked the real failure
            raise AnsibleConnectionFailure('unable to connect to remote host %s' % self.play_context.remote_addr)

        connection_time = datetime.datetime.now() - self._start_time
        display.display('connection established to %s in %s' % (play_context.remote_addr, connection_time), log_only=True)

        self.socket = socket.socket(socket.AF_UNIX, socket.SOCK_STREAM)
        self.socket.bind(path)
        self.socket.listen(1)

        # SIGALRM drives both the idle timeout and the per-request timeout
        signal.signal(signal.SIGALRM, self.alarm_handler)

    def dispatch(self, obj, name, *args, **kwargs):
        '''
        Call ``obj.<name>(*args, **kwargs)`` if that attribute exists and is
        truthy, returning its result; otherwise do nothing and return None.
        '''
        meth = getattr(obj, name, None)
        if meth:
            return meth(*args, **kwargs)

    def alarm_handler(self, signum, frame):
        '''
        SIGALRM handler: give the connection plugin a chance to clean up,
        then close the listening socket.
        '''
        # FIXME: this should also set internal flags for other
        #        areas of code to check, so they can terminate
        #        earlier than the socket going back to the accept
        #        call and failing there.
        #
        # hooks the connection plugin to handle any cleanup
        self.dispatch(self.conn, 'alarm_handler', signum, frame)
        self.socket.close()

    def run(self):
        '''
        Accept clients on the control socket and service their requests.

        Each request is one frame read with recv_data(). EXEC, PUT, FETCH
        and unrecognised requests are answered with three frames (rc,
        stdout, stderr); CONTEXT requests update the plugin's play context
        and get no reply. The loop ends when accept() fails, which is what
        happens once the alarm handler has closed the socket. On exit the
        connection and socket are closed and the socket file is removed.
        '''
        try:
            while True:
                # set the alarm, if we don't get an accept before it
                # goes off we exit (via an exception caused by the socket
                # getting closed while waiting on accept())
                # FIXME: is this the best way to exit? as noted above in the
                #        handler we should probably be setting a flag to check
                #        here and in other parts of the code
                signal.alarm(C.PERSISTENT_CONNECT_TIMEOUT)
                try:
                    (s, addr) = self.socket.accept()
                    display.display('incoming request accepted on persistent socket', log_only=True)
                    # clear the alarm
                    # FIXME: potential race condition here between the accept and
                    #        time to this call.
                    signal.alarm(0)
                except:
                    # deliberately broad: any failure to accept (normally the
                    # socket having been closed by the alarm) ends the server
                    break

                while True:
                    data = recv_data(s)
                    if not data:
                        break

                    # bound the time a single request may take
                    signal.alarm(self.play_context.timeout)

                    rc = 255
                    try:
                        if data.startswith(b'EXEC: '):
                            display.display("socket operation is EXEC", log_only=True)
                            # split only on the leading marker so a command
                            # that itself contains 'EXEC: ' is not truncated
                            cmd = data.split(b'EXEC: ', 1)[1]
                            (rc, stdout, stderr) = self.conn.exec_command(cmd)
                        elif data.startswith(b'PUT: ') or data.startswith(b'FETCH: '):
                            (op, src, dst) = shlex.split(to_native(data))
                            stdout = stderr = ''
                            try:
                                if op == 'FETCH:':
                                    display.display("socket operation is FETCH", log_only=True)
                                    self.conn.fetch_file(src, dst)
                                elif op == 'PUT:':
                                    display.display("socket operation is PUT", log_only=True)
                                    self.conn.put_file(src, dst)
                                rc = 0
                            except Exception:
                                # rc stays 255; report why the transfer failed
                                # instead of returning an empty stderr
                                stderr = traceback.format_exc()
                        elif data.startswith(b'CONTEXT: '):
                            display.display("socket operation is CONTEXT", log_only=True)
                            pc_data = data.split(b'CONTEXT: ', 1)[1]

                            # NOTE(security): unpickling executes arbitrary
                            # code; this trusts every peer able to connect to
                            # the control socket.
                            if PY3:
                                pc_data = cPickle.loads(pc_data, encoding='bytes')
                            else:
                                pc_data = cPickle.loads(pc_data)

                            pc = PlayContext()
                            pc.deserialize(pc_data)

                            self.dispatch(self.conn, 'update_play_context', pc)
                            # no reply is sent for CONTEXT requests
                            continue
                        else:
                            display.display("socket operation is UNKNOWN", log_only=True)
                            stdout = ''
                            stderr = 'Invalid action specified'
                    except:
                        # deliberately broad: whatever went wrong is sent
                        # back to the client as stderr with rc 255
                        stdout = ''
                        stderr = traceback.format_exc()

                    signal.alarm(0)

                    display.display("socket operation completed with rc %s" % rc, log_only=True)

                    send_data(s, to_bytes(rc))
                    send_data(s, to_bytes(stdout))
                    send_data(s, to_bytes(stderr))
                s.close()
        except Exception:
            display.display(traceback.format_exc(), log_only=True)
        finally:
            # when done, close the connection properly and cleanup
            # the socket file so it can be recreated
            end_time = datetime.datetime.now()
            delta = end_time - self._start_time
            display.display('shutting down control socket, connection was active for %s secs' % delta, log_only=True)
            try:
                self.conn.close()
                self.socket.close()
            except Exception:
                pass
            os.remove(self.path)

def main():
    '''
    Client entry point for the persistent connection helper.

    Reads a pickled play context from stdin (terminated by a line
    containing ``#END_INIT#``), derives the control socket path for the
    target, and, while holding a lock file, forks a Server daemon if the
    socket file does not exist yet. It then reads the first non-blank line
    from stdin, forwards it to the daemon preceded by a CONTEXT frame,
    writes the returned stdout/stderr to this process' streams and exits
    with the returned rc. Exits with status 255 if the control socket
    cannot be reached or closes before a full reply has been received.
    '''
    # Need stdin as a byte stream
    if PY3:
        stdin = sys.stdin.buffer
    else:
        stdin = sys.stdin

    try:
        # read the play context data via stdin, which means depickling it
        # FIXME: as noted above, we will probably need to deserialize the
        #        connection loader here as well at some point, otherwise this
        #        won't find role- or playbook-based connection plugins
        cur_line = stdin.readline()
        init_data = b''
        while cur_line.strip() != b'#END_INIT#':
            if cur_line == b'':
                raise Exception("EOF found before init data was complete")
            init_data += cur_line
            cur_line = stdin.readline()
        # NOTE(security): unpickling executes arbitrary code; this trusts
        # whatever process is writing to our stdin.
        if PY3:
            pc_data = cPickle.loads(init_data, encoding='bytes')
        else:
            pc_data = cPickle.loads(init_data)

        pc = PlayContext()
        pc.deserialize(pc_data)
    except Exception as e:
        # FIXME: better error message/handling/logging
        sys.stderr.write(traceback.format_exc())
        sys.exit("FAIL: %s" % e)

    # the ssh plugin's control path template is reused to name the socket
    ssh = connection_loader.get('ssh', class_only=True)
    m = ssh._create_control_path(pc.remote_addr, pc.port, pc.remote_user)

    # create the persistent connection dir if need be and create the paths
    # which we will be using later
    tmp_path = unfrackpath("$HOME/.ansible/pc")
    makedirs_safe(tmp_path)
    lk_path = unfrackpath("%s/.ansible_pc_lock" % tmp_path)
    sf_path = unfrackpath(m % dict(directory=tmp_path))

    # if the socket file doesn't exist, spin up the daemon process
    # (the exclusive lock serializes the existence check and daemon startup)
    lock_fd = os.open(lk_path, os.O_RDWR | os.O_CREAT, 0o600)
    fcntl.lockf(lock_fd, fcntl.LOCK_EX)
    if not os.path.exists(sf_path):
        pid = do_fork()
        if pid == 0:
            # daemonized grandchild: build the server, release the lock,
            # then serve until done; this branch never returns
            rc = 0
            try:
                server = Server(sf_path, pc)
            except AnsibleConnectionFailure as exc:
                display.display('connecting to host %s returned an error' % pc.remote_addr, log_only=True)
                display.display(str(exc), log_only=True)
                rc = 1
            except Exception:
                display.display('failed to create control socket for host %s' % pc.remote_addr, log_only=True)
                display.display(traceback.format_exc(), log_only=True)
                rc = 1
            fcntl.lockf(lock_fd, fcntl.LOCK_UN)
            os.close(lock_fd)
            if rc == 0:
                server.run()
            sys.exit(rc)
    else:
        display.display('re-using existing socket for %s@%s:%s' % (pc.remote_user, pc.remote_addr, pc.port), log_only=True)
    fcntl.lockf(lock_fd, fcntl.LOCK_UN)
    os.close(lock_fd)

    # now connect to the daemon process
    # FIXME: if the socket file existed but the daemonized process was killed,
    #        the connection will timeout here. Need to make this more resilient.
    rc = 0
    while rc == 0:
        data = stdin.readline()
        if data == b'':
            break
        if data.strip() == b'':
            continue
        sf = socket.socket(socket.AF_UNIX, socket.SOCK_STREAM)
        attempts = 1
        while True:
            try:
                sf.connect(sf_path)
                break
            except socket.error:
                # FIXME: better error handling/logging/message here
                time.sleep(C.PERSISTENT_CONNECT_INTERVAL)
                attempts += 1
                if attempts > C.PERSISTENT_CONNECT_RETRIES:
                    # only the message and log_only are passed, matching every
                    # other display.display() call in this file; host and user
                    # used to be passed as stray positional arguments
                    display.display('number of connection attempts exceeded, unable to connect to control socket', log_only=True)
                    display.display('persistent_connect_interval=%s, persistent_connect_retries=%s' % (C.PERSISTENT_CONNECT_INTERVAL, C.PERSISTENT_CONNECT_RETRIES), log_only=True)
                    sys.stderr.write('failed to connect to control socket')
                    sys.exit(255)

        # send the play_context back into the connection so the connection
        # can handle any privilege escalation activities
        # (plain concatenation: bytes %-formatting needs Python >= 3.5)
        pc_data = b'CONTEXT: ' + init_data
        send_data(sf, pc_data)

        send_data(sf, data.strip())

        # the reply is three frames: rc, stdout, stderr. recv_data() returns
        # None when the daemon closes the socket early, so check each frame
        # rather than crash in int() / print the string 'None'
        reply = []
        for _ in range(3):
            frame = recv_data(sf)
            if frame is None:
                sf.close()
                sys.stderr.write('control socket closed before a complete response was received')
                sys.exit(255)
            reply.append(frame)
        (rc_data, stdout, stderr) = reply

        rc = int(rc_data, 10)

        sys.stdout.write(to_native(stdout))
        sys.stderr.write(to_native(stderr))

        sf.close()
        # only the first non-blank command line is processed
        break

    sys.exit(rc)

if __name__ == '__main__':
    # ``display`` is bound at module level here; Server and main() look it
    # up as a global for all of their logging, so it must exist before
    # main() is called.
    display = Display()
    main()
