#!/usr/bin/env python
#------------------------------------------------------------------------------
# Copyright 2008-2012 Istituto Nazionale di Fisica Nucleare (INFN)
#
# Licensed under the EUPL, Version 1.1 only (the "Licence").
# You may not use this work except in compliance with the Licence.
# You may obtain a copy of the Licence at:
#
# http://joinup.ec.europa.eu/system/files/EN/EUPL%20v.1.1%20-%20Licence.pdf
#
# Unless required by applicable law or agreed to in
# writing, software distributed under the Licence is
# distributed on an "AS IS" basis,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND,
# either express or implied.
# See the Licence for the specific language governing
# permissions and limitations under the Licence.
#------------------------------------------------------------------------------
"""
WNoDeS socket communication utility.
"""

from OpenSSL import SSL  # @UnresolvedImport
import SocketServer  # @UnresolvedImport
import os
import pickle
import socket
import sys
import threading
import traceback

# Program identity, derived from the running script's path.
__version__ = (0, 0, 1)
__short_name__ = os.path.splitext(os.path.basename(sys.argv[0]))[0]
__long_name__ = "%s-%s" % (__short_name__,
                           ".".join([str(part) for part in __version__]))
__dir_name__ = os.path.dirname(sys.argv[0])
# Oldest interpreter version this module declares support for.
__minimum_python__ = (2, 4, 0)


def isOpen(host, port, timeout=1):
    """Check if a TCP connection to the pair (host, port) can be established.

    :param host: host name or IPv4 address to probe.
    :param port: TCP port, as an int or anything accepted by ``int()``.
    :param timeout: seconds to wait for the connection attempt (defaults to
        1, the value previously hard-coded here).
    :return: True if the connection succeeded, False on any error
        (refused, timed out, unresolvable host, non-numeric port, ...).
    """
    probe = socket.socket(socket.AF_INET, socket.SOCK_STREAM)
    try:
        probe.settimeout(timeout)
        try:
            probe.connect((host, int(port)))
        except Exception:
            return False
        return True
    finally:
        # Always release the probe socket, whatever happened above.
        probe.close()


def verify_cb(conn, cert, errnum, depth, ok):
    """Certificate verification callback passed to ``Context.set_verify``.

    Prints the subject of the certificate being checked and hands back
    ``ok`` unchanged, so the accept/reject decision is left to OpenSSL.

    TODO: this obviously has to be updated to apply a real policy.
    """
    subject = cert.get_subject()
    print('Got certificate: %s' % subject)
    return ok


class WnodesThreadingTCPServer(SocketServer.ThreadingTCPServer):
    """Threading TCP server that can wrap accepted connections in SSL.

    Whether SSL is used, and which key / certificate / CA files are loaded,
    is read from class attributes of ``self.RequestHandlerClass``
    (``encrypt``, ``server_key``, ``server_cert``, ``ca_cert``).
    """

    def _ssl_context(self):
        """Build the server-side SSL context from the handler configuration.

        Peers must present a certificate; it is checked against the CA file
        and reported through ``verify_cb``.
        """
        handler = self.RequestHandlerClass
        context = SSL.Context(SSL.SSLv23_METHOD)
        # Demand a certificate and fail the handshake if none is presented.
        context.set_verify(SSL.VERIFY_PEER | SSL.VERIFY_FAIL_IF_NO_PEER_CERT,
                           verify_cb)
        context.use_privatekey_file(handler.server_key)
        context.use_certificate_file(handler.server_cert)
        context.load_verify_locations(handler.ca_cert)
        return context

    def get_request(self):
        """Accept one connection and return ``(connection, client_address)``."""
        if not self.RequestHandlerClass.encrypt:
            return SocketServer.ThreadingTCPServer.get_request(self)
        # A fresh context is built for every accept, so the key/cert/CA files
        # are re-read each time.
        listener = SSL.Connection(self._ssl_context(), self.socket)
        return listener.accept()

    def shutdown_request(self, request):
        """Called to shutdown and close an individual request.

        The shutdown is best-effort: failures are ignored so that
        ``close_request`` is always reached and the socket is released.
        """
        try:
            # explicitly shutdown.  socket.close() merely releases
            # the socket and waits for GC to perform the actual close.
            if self.RequestHandlerClass.encrypt:
                request.shutdown()
            else:
                request.shutdown(socket.SHUT_WR)
        except (socket.error, SSL.Error):
            # Some platforms raise ENOTCONN here.  pyOpenSSL reports the
            # same kind of failure (e.g. peer already gone) as SSL.Error,
            # which previously escaped and skipped close_request().
            pass
        self.close_request(request)

try:
    # Project helpers: utils.whoami() is used in ClientRequestHandler's
    # error messages.  Without the package, the module exits at import time.
    from wnodes.utils import utils
except ImportError:
    sys.exit("%s: python module 'utils' not found." % __short_name__)


class ServerRequestHandler(SocketServer.StreamRequestHandler):
    """ TCP server handling connections from the HV or management processes.

    Each connection carries one request line and receives one response
    line.  On the wire an object is pickled, every newline in the pickle is
    replaced by the marker '1%2', and a single newline ends the message.
    (A payload that itself contains the text '1%2' is corrupted when
    decoded; this is a limitation of the wire format.)

    A request must be a dict with exactly one entry,
    ``{service_name: [parameters, ...]}``.  The service is looked up in
    ``serviceDispatchTable`` and called as ``service(parameters, source)``.
    ``source`` is the client's short host name, or its address when the
    name cannot be resolved.  Services must return ``[status, data]`` with
    an int ``status`` (0 on success); that value is sent back to the client.

    Configuration lives in class attributes, normally assigned by
    ``initServerRequestHandler``.
    """
    # SSL settings: enable flag and key / certificate / CA file paths.
    encrypt = False
    server_key = None
    server_cert = None
    ca_cert = None

    # Default buffer sizes for rfile, wfile.
    # We default rfile to buffered because otherwise it could be
    # really slow for large data (a getc() call per byte); we make
    # wfile unbuffered because (a) often after a write() we want to
    # read and we need to flush the line; (b) big writes to unbuffered
    # files are typically optimized by stdio even when big reads
    # aren't.
    rbufsize = -1
    wbufsize = 0

    # A timeout to apply to the request socket, if not None.
    timeout = None

    # Disable nagle algorithm for this socket, if True.
    # Use only when wbufsize != 0, to avoid small packets.
    disable_nagle_algorithm = False

    # Object providing updateLog(message[, level]); required by handle().
    extObj = None
    # dict mapping service names to callables; required by handle().
    serviceDispatchTable = None

    def __init__(self, request, client_address, server):
        if self.encrypt:
            # Replace the stock setup() with one that builds rfile/wfile
            # around the SSL connection object.
            self.setup = self.__ssl_setup
        SocketServer.StreamRequestHandler.__init__(self, request,
                                                   client_address, server)

    def __ssl_setup(self):
        """Define self.rfile and self.wfile for an SSL connection."""
        if self.server_key is None or self.server_cert is None:
            sys.exit("ServerRequestHandler not correctly initialized " +
                     "to use SSL encryption! (Missing key or cert file)")
        self.connection = self.request
        self.rfile = socket._fileobject(self.connection, 'rb', self.rbufsize)
        self.wfile = socket._fileobject(self.connection, 'wb', self.wbufsize)

    def updateLog(self, *args):
        """Fallback logger: print the first argument and ignore failures."""
        try:
            print(str(args[0]))
        except Exception:
            pass

    def _write_response(self, output):
        """Pickle ``output``, escape its newlines and send it as one line."""
        self.wfile.write(pickle.dumps(output).replace('\n', '1%2') + '\n')

    def invalid_request(self, *arguments):
        """Built-in service used when a request cannot be dispatched.

        Called like any other service, i.e. ``(parameters, source)``, where
        ``parameters[0]`` describes the problem.  Returns ``[1, message]``.
        """
        parameters = arguments[0]
        self.extObj.updateLog("executing invalid_request", "error")
        response = [1, "invalid request: %s" % (parameters[0],)]
        self.extObj.updateLog(response[1], "error")
        return response

    def handle(self):
        """Serve a single request: read, decode, dispatch and reply.

        Protocol errors and service failures are reported to the client as
        ``[1, message]`` responses.
        """
        if (self.extObj is None
                or not isinstance(self.serviceDispatchTable, dict)):
            self._write_response(
                [1, "ServerRequestHandler not correctly initialized!\n"])
            return

        try:
            self.extObj.updateLog(' -- Connected from %s '
                                  % str(self.client_address))
        except Exception:
            tp, vl, tb = sys.exc_info()
            self._write_response(
                [1, ('Exception calling updateLog of object type'
                     ' %s: %s, %s, %s'
                     % (str(type(self.extObj)), str(tp), str(vl), str(tb)))])
            return

        # Per-request copy: the shared class-level table is not mutated,
        # and it does not keep a reference to this handler instance.
        dispatch = dict(self.serviceDispatchTable)
        dispatch['invalid_request'] = self.invalid_request

        try:
            # try to get the (short) client hostname
            sourceAddr = socket.gethostbyaddr(
                str(self.client_address[0]))[0].split('.')[0]
        except Exception:
            # if hostname is unknown, use the IP address
            sourceAddr = str(self.client_address)

        try:
            received_data = self.rfile.readline().strip()
        except SSL.Error:
            err_msg = "Errors receiving data: %s" % str(sys.exc_info()[1])
            self.extObj.updateLog(err_msg, "error")
            self._write_response([1, err_msg])
            return

        if not received_data:
            return

        try:
            # SECURITY: pickle.loads on data read from the network can run
            # arbitrary code; this server must only be reachable by trusted
            # peers (e.g. via the SSL client-certificate check).
            request = pickle.loads(received_data.replace('1%2', '\n'))
        except Exception:
            err_msg = "Errors depickling data!"
            self.extObj.updateLog(err_msg, "error")
            self._write_response([1, err_msg])
            return

        self.extObj.updateLog("Requested service: %s" % str(request))

        if not (isinstance(request, dict) and len(request) == 1):
            err = ('Received data are not formatted as expected : %s'
                   % str(request))
            self.extObj.updateLog(err, "error")
            self._write_response([1, err])
            return

        requestedService = list(request.keys())[0]
        serviceParameters = request[requestedService]

        if requestedService not in dispatch:
            serviceParameters = ["no such service: %s" % (requestedService,)]
            requestedService = 'invalid_request'
        if not isinstance(serviceParameters, list):
            serviceParameters = [
                "'serviceParameters' match a %s and not an array"
                % (type(serviceParameters),)]
            requestedService = 'invalid_request'

        try:
            output = dispatch[requestedService](serviceParameters, sourceAddr)
        except:
            # Deliberately broad: any failure inside a service (even
            # sys.exit) is reported to the client instead of ending the
            # handler without a reply.
            tp, vl, tb = sys.exc_info()
            output = [1, ('Exception caught in function %s with parameters'
                          ' %s: %s, %s, %s '
                          % (requestedService, str(serviceParameters),
                             str(tp), str(vl), str(tb)))]
            traceback.print_exception(tp, vl, tb)

        # NOTE: output must be like: output = [ (int), (object) ]
        # where (int = 0) if the request was successful
        try:
            well_formed = type(output[0]) is int
        except Exception:
            well_formed = False
        if not well_formed:
            output = [1, ('Data structure returned by function %s, '
                          'with parameters %s is not formatted as expected'
                          % (requestedService, str(serviceParameters)))]

        if output[0] != 0:
            try:
                detail = output[1]
            except Exception:
                # e.g. a bare [1] carries no message; log the whole result
                detail = output
            self.extObj.updateLog(str(detail), "error")

        try:
            returnOutput = pickle.dumps(output).replace('\n', '1%2')
        except Exception:
            resp = [1, ('Error pickling the object returned by function %s'
                        ', with parameters %s'
                        % (requestedService, str(serviceParameters)))]
            returnOutput = pickle.dumps(resp).replace('\n', '1%2')

        self.wfile.write(returnOutput + '\n')


def initServerRequestHandler(extObj, SDT, encrypt=False, keyfile='',
    certfile='', cafile=''):
    """Configure ServerRequestHandler (class-wide) before serving.

    :param extObj: object with ``updateLog(message, level)``; used for logging.
    :param SDT: service dispatch table, a dict of service name -> callable.
    :param encrypt: enable SSL on the server side.
    :param keyfile: server private key file (required when ``encrypt``).
    :param certfile: server certificate file (required when ``encrypt``).
    :param cafile: CA certificate file used to verify client certificates
        (required when ``encrypt``).

    Calls ``sys.exit`` if encryption is requested without all three files.
    """
    ServerRequestHandler.extObj = extObj
    ServerRequestHandler.serviceDispatchTable = SDT
    ServerRequestHandler.encrypt = encrypt

    if not encrypt:
        extObj.updateLog("Server side encryption DISABLED", "info")
        return

    extObj.updateLog("Server side encryption ENABLED", "info")
    # WnodesThreadingTCPServer.get_request loads key, certificate AND CA
    # file on every SSL accept, so a missing CA file would only surface as
    # a failure on each connection; fail fast here instead.
    if not (keyfile and certfile and cafile):
        sys.exit("Enabling encryption requires: " +
                 "a 'keyfile', a 'certfile' and a 'cafile'")
    ServerRequestHandler.server_key = keyfile
    ServerRequestHandler.server_cert = certfile
    ServerRequestHandler.ca_cert = cafile


class ClientRequestHandler:
    """Client side of the WNoDeS socket protocol.

    Sends one pickled request per connection to a server built on
    ``ServerRequestHandler`` (optionally over SSL) and decodes the reply.
    """
    def __init__(self, host, port, key=None, cert=None, ca=None):
        """
        :param host: default server host used by ``sendRequest(msg)``.
        :param port: default server port used by ``sendRequest(msg)``.
        :param key: client private key file; passing it enables encryption
            by default for this client.
        :param cert: client certificate file.
        :param ca: CA certificate file used to verify the server.
        """
        self.host = host
        self.port = port
        self.client_key = None
        self.client_cert = None
        self.ca_cert = None
        self.encrypt = False
        if key:
            self.encrypt = True
            self.client_key = key
            self.client_cert = cert
            self.ca_cert = ca

    def updateLog(self, *args):
        """Fallback logger: print the first argument and ignore failures."""
        try:
            print(str(args[0]))
        except Exception:
            pass

    def encryptMessage(self, msg, pubkey):
        # NOTE(review): nothing in this class assigns self.cryptEngine, so
        # this raises AttributeError unless a caller/subclass provides it.
        return self.cryptEngine.encryptMessage(msg, pubkey)

    def decryptMessage(self, msg):
        # NOTE(review): see encryptMessage - relies on an external cryptEngine.
        return self.cryptEngine.decryptMessage(msg)

    def _serialize_input(self, msg, server_key=None, auth_string=None):
        """Encode ``msg`` for the wire.

        Returns the pickled message with newlines escaped as '1%2' (or the
        result of ``encryptMessage`` when ``server_key`` and ``auth_string``
        are given).  On error, an error is logged and a tuple is returned
        instead: ``(1, None)`` if ``msg`` does not have exactly one key,
        ``(1, '')`` for any other failure.
        """
        try:
            if len(msg.keys()) != 1:
                self.updateLog('%s: msg len(%s) is not 1 so it is not '
                               'formatted as expected'
                               % (utils.whoami(), len(msg.keys())),
                               "error")
                return (1, None)
            if server_key and auth_string:
                msg['auth'] = auth_string
                return self.encryptMessage(msg, server_key)
            return str(pickle.dumps(msg).replace('\n', '1%2'))
        except Exception:
            # Includes pickle.PicklingError: report it instead of exiting
            # the whole client process.
            self.updateLog('%s: %s'
                           % (utils.whoami(), str(sys.exc_info()[0])),
                           "error")
            return (1, '')

    def _marshall_output(self, response):
        """Decode one reply line from the server.

        Returns ``[0, reply]`` when the reply unpickles to a list (the
        server's own ``[status, data]``), otherwise ``[1, error_message]``.
        """
        try:
            # SECURITY: unpickling network data can execute arbitrary code;
            # only talk to trusted servers.
            output = pickle.loads(response.replace('1%2', '\n'))
        except pickle.UnpicklingError:
            return [1, 'Data received cannot be loaded: Unpickling error']
        except Exception:
            exc_info = sys.exc_info()
            traceback.print_exception(exc_info[0], exc_info[1], exc_info[2])
            return [1, ('Data received cannot be loaded for this reason: ' +
                        '%s, %s, %s' % exc_info)]
        if isinstance(output, list):
            return [0, output]
        # Wrap in a tuple so tuple replies are formatted, not expanded.
        return [1, "wrong exit syntax: %s" % (output,)]

    def sendRequest(self, *args, **kwargs):
        """Send a request and return the decoded reply.

        Call as ``sendRequest(msg)`` to use the default host/port, or as
        ``sendRequest(host, port, msg)``.  The ``encrypt`` keyword overrides
        the instance default for this call.  See ``_sendRequest`` for the
        message and return formats.
        """
        if len(args) == 1:
            host, port, msg = self.host, self.port, args[0]
        elif len(args) == 3:
            host, port, msg = args
        else:
            sys.exit('Function sendRequest takes 1 or 3 arguments (%d given)'
                     % len(args))

        if 'encrypt' in kwargs:
            enc = kwargs['encrypt']
            self.updateLog(('Keyword encrypt found for request %s, '
                            'setting encryption to: %s'
                            % (str(msg), str(enc))),
                           "info")
        else:
            enc = self.encrypt

        return self._sendRequest(host, port, msg, enc)

    def _shutdown_quietly(self, sock, enc):
        """Best-effort shutdown of a finished connection; errors are ignored.

        The reply has already been read when this runs, so failing to shut
        down cleanly (e.g. the server already closed its end) must not turn
        a successful request into an error.
        """
        try:
            if enc:
                sock.shutdown()
            else:
                sock.shutdown(socket.SHUT_RDWR)
        except Exception:
            pass

    def _sendRequest(self, host, port, msg, enc):
        """
        Send a msg to a WNoDeS TCP socket server.
        Msg format MUST BE a dictionary with one key-value pair.
        Key MUST BE the name of the method to execute on the server.
        Value MUST BE the list of parameters for that method.

        Returns a list: ``[0, reply]`` on success, where ``reply`` is the
        server's own ``[status, data]`` list, or ``[1, error_message]`` if
        the port, the message, the connection or the reply is invalid.
        The socket is always closed before returning.
        """
        if type(port) is str:
            try:
                port = int(port)
            except ValueError:
                err = ('Invalid format for port %s: '
                       'string cannot be converted to int' % port)
                self.updateLog(err, "error")
                return [1, err]
        elif type(port) is not int:
            err = ('Invalid format for port %s: must be <int> or <string>'
                   % (port,))
            self.updateLog(err, "error")
            return [1, err]

        # Serialize before connecting: a malformed message needs no socket.
        request = self._serialize_input(msg)
        if not isinstance(request, str):
            err = 'Request %s is not formatted as expected' % (msg,)
            self.updateLog(err, "error")
            return [1, err]

        sock = None
        try:
            try:
                sock = socket.socket(socket.AF_INET, socket.SOCK_STREAM)
                if enc:
                    context = SSL.Context(SSL.SSLv23_METHOD)
                    # Demand a certificate from the server.
                    context.set_verify(SSL.VERIFY_PEER, verify_cb)
                    context.use_privatekey_file(self.client_key)
                    context.use_certificate_file(self.client_cert)
                    context.load_verify_locations(self.ca_cert)
                    sock = SSL.Connection(context, sock)
                    sock.connect((host, port))
                    sock.do_handshake()
                    self.updateLog(sock.state_string())
                else:
                    sock.connect((host, port))

                socketfile = socket._fileobject(sock, 'rb')
                sock.sendall(request + '\n')
                response = socketfile.readline().strip()
                self._shutdown_quietly(sock, enc)
                return self._marshall_output(response)

            except socket.error:
                a, b, c = sys.exc_info()
                traceback.print_exception(a, b, c)
                err = 'Server host %s:%s is unreachable' % (host, port)
                self.updateLog(err, "error")
                return [1, err]
            except Exception:
                a, b, c = sys.exc_info()
                traceback.print_exception(a, b, c)
                err = ('Exception in %s: %s'
                       % (utils.whoami(), str(sys.exc_info())))
                self.updateLog(err, "error")
                return [1, err]
        finally:
            # Release the connection on every path (it used to leak on errors).
            if sock is not None:
                sock.close()


class Connection(ClientRequestHandler):
    """Client bound to a fixed host/port, optionally logging through extObj."""

    # lock used to rotate logfiles (not used in this module itself;
    # presumably shared by callers - verify before removing)
    LOCK_LOG = threading.Lock()

    def __init__(self, host, port, extObj=None):
        """
        :param host: server host.
        :param port: server port; converted with ``int()`` (ValueError if
            it is not numeric).
        :param extObj: optional object whose ``updateLog`` replaces the
            default print-based logger.  When omitted, an ``extObj``
            attribute already visible on the instance (e.g. defined by a
            subclass) is used instead, if present.
        """
        ClientRequestHandler.__init__(self, host, port)
        self.host = host
        self.port = int(port)
        # Previously the argument was ignored and only self.extObj (never
        # set here) was consulted, so the external logger was never used.
        if extObj is not None:
            self.extObj = extObj
        logger = getattr(self, 'extObj', None)
        if logger is not None:
            try:
                self.updateLog = logger.updateLog
            except AttributeError:
                pass


def main():
    """Script entry point; this module currently does nothing when run."""
    return None


if __name__ == "__main__":
    main()
