#!/usr/bin/env python
#------------------------------------------------------------------------------
# Copyright 2008-2012 Istituto Nazionale di Fisica Nucleare (INFN)
#
# Licensed under the EUPL, Version 1.1 only (the "Licence").
# You may not use this work except in compliance with the Licence.
# You may obtain a copy of the Licence at:
#
# http://joinup.ec.europa.eu/system/files/EN/EUPL%20v.1.1%20-%20Licence.pdf
#
# Unless required by applicable law or agreed to in
# writing, software distributed under the Licence is
# distributed on an "AS IS" basis,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND,
# either express or implied.
# See the Licence for the specific language governing
# permissions and limitations under the Licence.
#------------------------------------------------------------------------------
"""
WNoDeS pre-execution script
"""

import socket
import commands
import os
import sys
import time
import pickle
import threading
import optparse
import ConfigParser


class arbitrary_dict(dict):
    """A dictionary which applies an arbitrary
    key-altering function before accessing the keys.

    Subclasses override ``__keytransform__``. The transformed key is used by
    item access, membership tests, deletion, the constructor, ``update``,
    ``get``, ``pop`` and ``setdefault``. Other dict methods (``copy``,
    ``fromkeys``, ...) behave like a plain dict.
    """

    def __keytransform__(self, key):
        # Identity by default; subclasses supply the real transformation.
        return key

    # Overriden methods. List from a
    # http://stackoverflow.com/questions/2390827/how-to-properly-subclass-dict
    def __init__(self, *args, **kwargs):
        self.update(*args, **kwargs)

    def __getitem__(self, key):
        return super(arbitrary_dict, self).__getitem__(self.__keytransform__(key))

    def __setitem__(self, key, value):
        return super(arbitrary_dict, self).__setitem__(self.__keytransform__(key), value)

    def __delitem__(self, key):
        return super(arbitrary_dict, self).__delitem__(self.__keytransform__(key))

    def __contains__(self, key):
        return super(arbitrary_dict, self).__contains__(self.__keytransform__(key))

    def get(self, key, default=None):
        return super(arbitrary_dict, self).get(self.__keytransform__(key), default)

    def pop(self, key, *default):
        return super(arbitrary_dict, self).pop(self.__keytransform__(key), *default)

    def setdefault(self, key, default=None):
        return super(arbitrary_dict, self).setdefault(self.__keytransform__(key), default)

    def update(self, *args, **kwargs):
        """Like ``dict.update`` but every key goes through ``__keytransform__``.

        Accepts at most one positional argument, either a mapping (anything
        with ``keys()``) or an iterable of ``(key, value)`` pairs, plus
        keyword arguments.

        Raises TypeError when more than one positional argument is given.
        """
        if len(args) > 1:
            raise TypeError('update expected at most 1 arguments, got %d'
                            % len(args))
        trans = self.__keytransform__
        setitem = super(arbitrary_dict, self).__setitem__
        if args:
            other = args[0]
            if hasattr(other, 'keys'):
                pairs = [(k, other[k]) for k in other.keys()]
            else:
                pairs = other
            for k, v in pairs:
                setitem(trans(k), v)
        for k in kwargs:
            setitem(trans(k), kwargs[k])


class lcdict(arbitrary_dict):
    """Case-insensitive arbitrary_dict.

    Keys are converted with str() and lower-cased before every lookup or
    store, so 'General', 'GENERAL' and 'general' address the same entry.
    """

    def __keytransform__(self, key):
        as_text = str(key)
        return as_text.lower()


class Connection(threading.Thread):
    """ Provide a socket-threaded connection using WNoDeS syntax.

    Each sendRequest() opens a new TCP connection to (host, port), sends one
    pickled, newline-terminated message and reads back one newline-terminated
    pickled reply.
    """

    # lock used to rotate logfiles
    LOCK_LOG = threading.Lock()

    def __init__(self, host, port):
        threading.Thread.__init__(self)
        self.host = host
        self.port = int(port)

    def sendRequest(self, msg):
        """
        Send a msg to a WNoDeS TCP socket server.

        Msg format MUST BE a dictionary with exactly one key/value pair.
        Key MUST BE the name of the method to execute on the server.
        Value MUST BE the list of arguments for that method (if any).

        It returns the decoded reply, a tuple with this format (status, data).

        Raises ValueError if msg cannot be serialized (e.g. it does not have
        exactly one key) and socket.error on network failures. The socket is
        always closed, whatever happens.
        """
        request = self._serialize_input(msg)
        if isinstance(request, tuple):
            # _serialize_input reports failure with an error tuple instead of
            # a string; sending it would only fail later with a TypeError.
            raise ValueError('request could not be serialized')

        sock = socket.socket(socket.AF_INET, socket.SOCK_STREAM)
        try:
            sock.connect((self.host, self.port))
            sock.sendall(request + '\n')
            socketfile = sock.makefile('r', 0)
            try:
                response = socketfile.readline().strip()
            finally:
                socketfile.close()
            # 2 == SHUT_RDWR: no further sends or receives on this socket.
            sock.shutdown(2)
        finally:
            sock.close()
        return self._marshall_output(response)

    def _serialize_input(self, msg):
        """Pickle msg into a single line for the wire.

        Newlines of the pickle are replaced by the placeholder '1%2' so the
        message fits in one line; _marshall_output reverses this.
        NOTE(review): a literal '1%2' inside the payload would also be turned
        into '\\n' by the receiver; this is a protocol limitation.

        Returns the encoded string, or an error tuple ((1, None) when msg
        does not have exactly one key, (1, '') on any other error).
        Exits the process on pickle.PicklingError.
        """
        try:
            if len(msg.keys()) == 1:
                msg = str(pickle.dumps(msg).replace('\n', '1%2'))
                return msg
            else:
                print ('msg len(%s) is not 1 : not formatted as expected'
                                % len(msg.keys()))
                response = (1, None)
                return response
        except pickle.PicklingError:
            sys.exit('Data received cannot be loaded')
        except Exception:
            print ('%s' % str(sys.exc_info()[0]))
            response = (1, '')
            return response

    def _marshall_output(self, response):
        """Decode a reply line produced by the server (inverse of
        _serialize_input) and return the unpickled object.

        Exits the process (sys.exit with a message) if the reply cannot be
        unpickled.
        """
        # SECURITY: pickle.loads on data received from the network can execute
        # arbitrary code; only ever talk to trusted WNoDeS services.
        try:
            output = pickle.loads(response.replace('1%2', '\n'))

        except pickle.UnpicklingError:
            sys.exit('Data received cannot be loaded: Unpickling error')
        except Exception:
            exctype, value = sys.exc_info()[:2]
            sys.exit('Data received cannot be loaded for this reason: %s %s'
                     % (exctype, value))
        return output


def whoIs_TheBait(PX_CONFIG, USERNAME, jobid=None, host_name=None):
    """Return ``[status, bait]`` for the current job.

    The cache file PX_CONFIG['TMPFILE'] is searched first for a line of the
    form ``BAIT%<job id>%<bait>`` (the format this function writes). When the
    file is missing/unreadable or holds no entry for this job, the WNoDeS
    name server (NS_HOST:NS_PORT) is asked with a ``whoIs_TheBait`` request;
    if it answers with status 0 the result is appended to the cache file.

    PX_CONFIG -- settings of the [general] section (TMPFILE, NS_HOST,
                 NS_PORT and FAIL_RETURN_STATUS are used)
    USERNAME  -- selects the debug log file
    jobid     -- batch job ID; defaults to the module-level ``job_id``
    host_name -- host name sent to the name server; defaults to the
                 module-level ``hostname``

    Returns [0, bait] on a cache hit, otherwise [status, str(bait)] as
    answered by the name server. Any error while querying the name server or
    writing the cache is logged and the process exits with
    FAIL_RETURN_STATUS.
    """
    if jobid is None:
        jobid = job_id

    # Cache lookup. The job-ID field is compared exactly: a substring test
    # would let job "12" pick up the cached entry of job "123".
    try:
        cache = open(PX_CONFIG['TMPFILE'])
        try:
            for line in cache:
                fields = line.strip().split('%')
                if len(fields) > 2 and fields[1] == jobid:
                    return [0, fields[2]]
        finally:
            cache.close()
    except IOError:
        # Missing or unreadable cache: ask the name server instead.
        pass

    if host_name is None:
        host_name = hostname

    try:
        c = Connection(PX_CONFIG['NS_HOST'], PX_CONFIG['NS_PORT'])
        msg = {'whoIs_TheBait': [host_name]}
        OUTPUT = c.sendRequest(msg)
        my_bait = str(OUTPUT[1])
        if OUTPUT[0] == 0:
            # there is bait so this is a job running on host supporting WNoDeS
            flog = open(PX_CONFIG['TMPFILE'], 'a')
            try:
                flog.write('BAIT%' + jobid + '%' + my_bait + '\n')
            finally:
                flog.close()
            # The cache is made world-writable (mode 0777, spelled portably),
            # presumably so jobs of every user can share it. Only the file
            # owner may chmod it, so for anybody else this fails with EPERM;
            # that must not make the job fail.
            try:
                os.chmod(PX_CONFIG['TMPFILE'], int('777', 8))
            except OSError:
                pass
        return [OUTPUT[0], my_bait]
    except:
        # Deliberately bare: Connection's helpers call sys.exit() on bad
        # replies, and that SystemExit must also end up here so the process
        # exits with FAIL_RETURN_STATUS.
        exctype, value = sys.exc_info()[:2]
        err_msg = ("whoIs_TheBait fails: host %s, port: %s, "
                   "            error %s: %s" % (PX_CONFIG['NS_HOST'],
                                                 PX_CONFIG['NS_PORT'],
                                                 exctype, value))
        debug(USERNAME, err_msg)
        # sys.exit() needs an int: the raw config string would exit with 1.
        sys.exit(int(PX_CONFIG['FAIL_RETURN_STATUS']))


def retriveName(string):
    """Return the characters found between '(' and ')' in *string*.

    Text from every parenthesised group is concatenated, e.g.
    '1000(users)' -> 'users' and 'a(b)c(d)' -> 'bd'. Nesting is not
    tracked: any ')' ends collection and any '(' (re)starts it.
    """
    collected = []
    inside = False
    for ch in string:
        if ch == ')':
            inside = False
        elif ch == '(':
            inside = True
        elif inside:
            collected.append(ch)
    return ''.join(collected)


# Useful for debugging only
def debug(username, debug_msg):
    """Append a timestamped message to the per-user debug log.

    The log is /tmp/<username>_debug_px_wnodes. Each entry is written as
    "<time.asctime()> - <str(debug_msg)> " followed by a newline. The file
    is closed even if the write fails.
    """
    flog = open('/tmp/%s_debug_px_wnodes' % username, 'a')
    try:
        flog.write('%s - %s \n' % (time.asctime(), str(debug_msg)))
    finally:
        flog.close()


def getDefaultConfig(config_file, USERNAME):
    """Read the [general] section of the pre-execution configuration file.

    config_file -- path of the configuration file (may be None)
    USERNAME    -- selects the debug log file used for error messages

    Returns a dict mapping the upper-cased option names of the [general]
    section to their (string) values. Section names are matched
    case-insensitively because the parser stores sections in an lcdict.

    Logs through debug() and exits with status 3 when the file is missing,
    cannot be parsed, has no [general] section, or lacks one of the
    mandatory parameters.
    """
    PX_MANDATORY_PARAM = ['TMPFILE', 'LOCAL_DOMAIN',
                          'NS_HOST', 'NS_PORT', 'FAIL_RETURN_STATUS']

    if config_file is None or not os.path.isfile(config_file):
        debug(USERNAME, 'Config file %s is not available' % str(config_file))
        sys.exit(3)

    try:
        px_conf_file = ConfigParser.RawConfigParser()
        # NOTE: swaps RawConfigParser's private storage for lcdict to get
        # case-insensitive section names; relies on parser internals.
        px_conf_file._sections = lcdict()
        px_conf_file._defaults = lcdict()
        px_conf_file.read(config_file)
    except (ConfigParser.MissingSectionHeaderError,
            ConfigParser.ParsingError):
        exctype, value = sys.exc_info()[:2]
        debug(USERNAME, 'There are errors %s; %s' % (exctype, value))
        sys.exit(3)

    # Read parameter from GENERAL section
    if not px_conf_file.has_section('general'):
        debug(USERNAME, 'Missing mandatory section GENERAL in the configuration file')
        sys.exit(3)

    PX_CONFIG = dict([(name.upper(), value)
                      for name, value in px_conf_file.items('general')])

    # Check whether all mandatory parameters have been configured or not
    for param in PX_MANDATORY_PARAM:
        if param not in PX_CONFIG:
            debug(USERNAME, 'Missing mandatory PARAM %s in the GENERAL section' % param)
            sys.exit(3)

    return PX_CONFIG


def getVmConfig(CONFIG_FILE, USER_DETAILS):
    """Return the VM parameters that apply to the current user.

    CONFIG_FILE  -- path of the pre-execution configuration file
    USER_DETAILS -- sequence of names (the caller passes the user name
                    first, then group names). The LAST of them that matches
                    a section name (case-insensitively) selects the section;
                    otherwise the [default] section is used.
                    NOTE(review): a group section listed later therefore
                    overrides a user-specific section; confirm this priority
                    is intended.

    Returns a dict keyed by the upper-cased option names of the chosen
    section. A missing file, a missing section or a missing mandatory
    parameter is logged (debug log of USER_DETAILS[0]) and the process exits
    with status 3.
    """
    VM_MANDATORY_PARAM = ['MEM', 'CPU', 'STORAGE', 'BANDWIDTH', 'IMG',
                          'TYPE', 'NETWORK_TYPE', 'ENABLEVIRTIO', 'PX_SCRIPT']

    if CONFIG_FILE is None or not os.path.isfile(CONFIG_FILE):
        debug(USER_DETAILS[0], 'Config file %s is not available' % str(CONFIG_FILE))
        sys.exit(3)

    px_conf_file = ConfigParser.RawConfigParser()
    # Case-insensitive section names (relies on RawConfigParser internals).
    px_conf_file._sections = lcdict()
    px_conf_file._defaults = lcdict()
    px_conf_file.read(CONFIG_FILE)

    MY_SECTION = 'default'
    for k in USER_DETAILS:
        if px_conf_file.has_section(k.lower()):
            MY_SECTION = k.lower()

    if not px_conf_file.has_section(MY_SECTION):
        debug(USER_DETAILS[0], 'Missing section %s in the configuration file' % MY_SECTION)
        sys.exit(3)

    vmParameters = dict([(name.upper(), value)
                         for name, value in px_conf_file.items(MY_SECTION)])

    # For TYPE = BATCH_REAL only a reduced set of parameters is mandatory.
    # .get() so that a missing TYPE is reported by the check below instead
    # of crashing here with a KeyError.
    if vmParameters.get('TYPE', '').upper() == 'BATCH_REAL':
        VM_MANDATORY_PARAM = ['MEM', 'CPU', 'STORAGE', 'BANDWIDTH',
                              'TYPE', 'PX_SCRIPT']

    # Check whether all mandatory parameters have been configured or not
    for param in VM_MANDATORY_PARAM:
        if param not in vmParameters:
            debug(USER_DETAILS[0], 'Missing mandatory PARAM %s in the vmParameters section' % param)
            sys.exit(3)

    return vmParameters


def isOpen(host, port):
    """Tell whether a TCP connection to (host, port) can be established.

    Makes one connection attempt with a 1 second timeout; any exception
    raised by connect counts as "not open". The probe socket is closed
    before returning True or False.
    """
    probe = socket.socket(socket.AF_INET, socket.SOCK_STREAM)
    probe.settimeout(1)
    try:
        probe.connect((host, port))
        reachable = True
    except Exception:
        reachable = False
    probe.close()
    return reachable

# Module-level state; both dictionaries are assigned again further down.
PX_CONFIG = {}
vmParameters = {}

# Short host name: everything before the first '.' of the host name.
hostname = socket.gethostname().split('.')[0]

# Command line options. Under LSF the job ID and user are read from the
# environment, so -j and -u are only needed for PBS.
usage_str = 'usage: %prog [options]'
optionArgs = optparse.OptionParser(usage=usage_str)

for _short, _long, _dest, _help in (
    ('-f', '--file', 'config_file',
     'full path of pre execution script configuration file'),
    ('-j', '--jobid', 'job_id',
     'Only for PBS. Specify the jobid'),
    ('-u', '--username', 'username',
     'Only for PBS. Specify the username'),
):
    optionArgs.add_option(_short, _long, action='store', type='string',
                          dest=_dest, help=_help)

(options, args) = optionArgs.parse_args()


BATCH_SYSTEM = ''
OUTPUT = {}

# Discover the batch system type: if `lsid` succeeds after loading the LSF
# environment this is an LSF host, otherwise PBS is assumed.
# The POSIX '.' builtin is used instead of the bash-only 'source' because
# the commands module runs the string through /bin/sh, which is not bash on
# every distribution (with dash, 'source' fails and LSF would be missed).
LSID = commands.getstatusoutput('. /etc/profile.d/lsf.sh;lsid')
if LSID[0] == 0:
    # LSF: job ID, allocated CPUs and user come from the job environment.
    job_id = os.environ.get("LSB_JOBID")
    CPU = os.environ.get("LSB_MCPU_HOSTS")
    USERNAME = os.environ.get("USER")
    BATCH_SYSTEM = 'LSF'
    OUTPUT[0] = 0
else:
    # PBS: job ID and user come from the command line (-j / -u). CPU is a
    # two-token placeholder so the parallel-job check below never fires.
    job_id = options.job_id
    CPU = '000 111'
    USERNAME = options.username
    BATCH_SYSTEM = 'PBS'
    OUTPUT[0] = 3

# parallel jobs are not supported. PX must fail in case of parallel jobs
# (i.e. more than two whitespace-separated tokens in CPU).
if len(CPU.split()) > 2:
    sys.exit(3)

# Obtain groups of belonging: getVmConfig is given the user name followed by
# the name of every group listed by `id`.
ID_OUTPUT = commands.getoutput('id ' + USERNAME).split()
if len(ID_OUTPUT) > 2:
    # The third token normally lists the groups, e.g.
    # "groups=100(users),27(wheel)"; retriveName keeps the text in
    # parentheses of each comma-separated entry. Index-based on purpose:
    # the "groups=" label may be localized.
    groups = ID_OUTPUT[2].split(',')
else:
    # Short/unexpected `id` output: look up the user name only.
    groups = []
GROUPS = (USERNAME,) + tuple([retriveName(group) for group in groups])

PX_CONFIG = getDefaultConfig(options.config_file, USERNAME)
vmParameters = getVmConfig(options.config_file, GROUPS)
vmParameters['BATCH_JOBID'] = job_id
vmParameters['USER_DETAILS'] = GROUPS

# Find the TCP port of the bait service: ask the WNoDeS name server first and
# fall back to BAIT_PORT from the local [general] section.
try:
    c = Connection(PX_CONFIG['NS_HOST'], int(PX_CONFIG['NS_PORT']))
    msg = {'get_config_from_file': ['BAIT']}
    OUTPUT = c.sendRequest(msg)
    BAIT_PORT = OUTPUT[1]['BAIT_PORT']
except:
    # Deliberately bare: this also catches the SystemExit that Connection
    # raises on undecodable replies, so the local fallback is still tried.
    exctype, value = sys.exc_info()[:2]
    err_msg = "failed to contact wnodes ns at the following address: host %s, port: %s, \
               error %s: %s" % (PX_CONFIG['NS_HOST'], PX_CONFIG['NS_PORT'], exctype, value)
    debug(USERNAME, err_msg)

    # try to get bait port from the local config file
    try:
        BAIT_PORT = PX_CONFIG['BAIT_PORT']
    except KeyError:
        debug(USERNAME, 'BAIT_PORT is not locally defined in the px conf file')
        # sys.exit() needs an int: the raw config string would exit with 1.
        sys.exit(int(PX_CONFIG['FAIL_RETURN_STATUS']))


# This host is a bait if something is listening on BAIT_PORT locally.
HOST_IS_BAIT = isOpen('127.0.0.1', int(BAIT_PORT))

if HOST_IS_BAIT:
    # A service answers on BAIT_PORT on localhost: send the job's VM request
    # (requestVMInstance) to it and exit with the status it returns.

    # Try to discover whether the user
    # has specified VM parameters at submission time or not
    if BATCH_SYSTEM == 'LSF':
        # '.' rather than the bash-only 'source': the command runs under /bin/sh.
        BJOBS_OUPUT = commands.getstatusoutput('. /etc/profile.d/lsf.sh; bjobs -l ' + str(job_id))
        if BJOBS_OUPUT[0] == 0:
            if 'WNoDeS' in str(BJOBS_OUPUT[1]):
                # Join the output into one line and drop the 22-space runs
                # (presumably the indentation of bjobs -l continuation
                # lines), then inspect every ';'-separated field.
                BJOBS = ' '.join(BJOBS_OUPUT[1].split('\n')).replace('                      ', '')
                for line in BJOBS.split(';'):
                    if 'WNoDeS_' in line:
                        # NAME:VALUE, where VALUE may itself contain ':'.
                        CONFIG_LINE = line.strip().split(':')
                        if 'VM_CONFIG' in line:
                            # 4th and 5th '_'-separated fields of NAME.
                            COMPONENT = CONFIG_LINE[0].split('_')[3]
                            ELEMENT_CONFIG = CONFIG_LINE[0].split('_')[4]
                            VALUE = ':'.join(CONFIG_LINE[1:len(CONFIG_LINE)])
                            vmParameters['VM_CONFIG_%s_%s' % (COMPONENT, ELEMENT_CONFIG)] = VALUE
                        else:
                            # 3rd '_'-separated field of NAME.
                            VARIABLE = CONFIG_LINE[0].split('_')[2]
                            VALUE = ':'.join(CONFIG_LINE[1:len(CONFIG_LINE)])
                            vmParameters[VARIABLE] = VALUE

    elif BATCH_SYSTEM == 'PBS':
        try:
            from xml.etree import ElementTree
        except ImportError:
            try:
                from elementtree import ElementTree  # pylint: disable-msg=F0401
            except ImportError:
                debug(USERNAME, 'Failed to import elementTree module')
                sys.exit(int(PX_CONFIG['FAIL_RETURN_STATUS']))

        cmdLine = ('qstat -x %s' % str(job_id))
        bjobsCmdOutput = commands.getstatusoutput(cmdLine)

        if bjobsCmdOutput[0] == 0:
            JOB_INFO = ElementTree.fromstring(bjobsCmdOutput[1])
            # Text of the first <Variable_List> element; '' if the element is
            # absent or empty (an empty element has text None).
            VARIABLES = ''
            for i in JOB_INFO.getiterator():
                if i.tag == 'Variable_List':
                    VARIABLES = i.text or ''
                    break

            # Try to find WNoDeS parameters in the comma separated NAME=VALUE
            # entries. Only NAME decides whether an entry is a WNoDeS setting
            # (a value that merely mentions WNoDeS, e.g. a path, is ignored),
            # and VALUE is everything after the first '=' (it may contain '=').
            for k in VARIABLES.split(','):
                CONFIG_LINE = k.split('=', 1)
                NAME = CONFIG_LINE[0]
                VALUE = '='.join(CONFIG_LINE[1:])
                if 'WNoDeS' in NAME:
                    if 'VM_CONFIG' in NAME:
                        COMPONENT = NAME.split('_')[3]
                        ELEMENT_CONFIG = NAME.split('_')[4]
                        vmParameters['VM_CONFIG_%s_%s' % (COMPONENT, ELEMENT_CONFIG)] = VALUE
                    else:
                        VARIABLE = NAME.split('_')[2]
                        vmParameters[VARIABLE] = VALUE

    HOST = '%s.%s' % (hostname, PX_CONFIG['LOCAL_DOMAIN'])
    try:
        c = Connection(HOST, int(BAIT_PORT))
        msg = {'requestVMInstance': [job_id, vmParameters]}
        OUTPUT = c.sendRequest(msg)
        RETURN_STATUS = OUTPUT[0]
    except:
        # Bare on purpose: also catches SystemExit raised inside Connection.
        exctype, value = sys.exc_info()[:2]
        err_msg = "requestVMInstance failed: host %s, port: %s, \
        error %s: %s" % (HOST, BAIT_PORT, exctype, value)
        debug(USERNAME, err_msg)

        # sys.exit() needs an int: the raw config string would exit with 1.
        RETURN_STATUS = int(PX_CONFIG['FAIL_RETURN_STATUS'])
    sys.exit(RETURN_STATUS)


else:
    # Support for mixed mode.
    # This host is not a BAIT. Request came from a Real WN,
    # a Virtual WN or a WN which support WNoDeS mixed mode

    if vmParameters['TYPE'].upper() == 'BATCH_REAL':
        # Run the optional local PX_SCRIPT and exit 0 / FAIL_RETURN_STATUS.
        if not vmParameters['PX_SCRIPT'] == '':
            PX_REPORT = commands.getstatusoutput(vmParameters['PX_SCRIPT'])
        else:
            PX_REPORT = [0, '']

        if PX_REPORT[0] == 0:
            sys.exit(0)
        else:
            debug(USERNAME, str(PX_REPORT))
            sys.exit(int(PX_CONFIG['FAIL_RETURN_STATUS']))

    else:
        BAIT_HOST = whoIs_TheBait(PX_CONFIG, USERNAME)

        if BAIT_HOST[0] == 0:
            # this is a job managed by WNoDeS
            BAIT_HOST = BAIT_HOST[1]

            if not vmParameters['PX_SCRIPT'] == '':
                PX_REPORT = commands.getstatusoutput(vmParameters['PX_SCRIPT'])
            else:
                PX_REPORT = [0, '']

            # Report the PX_SCRIPT outcome to the bait. On success the bait's
            # status is the exit status; on failure the job always exits with
            # FAIL_RETURN_STATUS (an int: a list or string would exit with 1).
            if PX_REPORT[0] == 0:
                msg = {'reportPreExecutionScript': [job_id, 0, 'Everything is ok']}
            else:
                msg = {'reportPreExecutionScript': [job_id, 1, str(PX_REPORT)]}
            try:
                c = Connection(BAIT_HOST, int(BAIT_PORT))
                OUTPUT = c.sendRequest(msg)
                if PX_REPORT[0] == 0:
                    RETURN_STATUS = OUTPUT[0]
                else:
                    RETURN_STATUS = int(PX_CONFIG['FAIL_RETURN_STATUS'])
            except:
                # Bare on purpose: also catches SystemExit raised inside Connection.
                exctype, value = sys.exc_info()[:2]
                err_msg = ("reportPreExecutionScript failed: host %s, port: %s, "
                           "error %s: %s" % (BAIT_HOST, BAIT_PORT, exctype, value))
                debug(USERNAME, err_msg)
                RETURN_STATUS = int(PX_CONFIG['FAIL_RETURN_STATUS'])

            sys.exit(RETURN_STATUS)

        else:
            sys.exit(int(PX_CONFIG['FAIL_RETURN_STATUS']))
