#------------------------------------------------------------------------------
# Copyright 2008-2012 Istituto Nazionale di Fisica Nucleare (INFN)
#
# Licensed under the EUPL, Version 1.1 only (the "Licence").
# You may not use this work except in compliance with the Licence.
# You may obtain a copy of the Licence at:
#
# http://joinup.ec.europa.eu/system/files/EN/EUPL%20v.1.1%20-%20Licence.pdf
#
# Unless required by applicable law or agreed to in
# writing, software distributed under the Licence is
# distributed on an "AS IS" basis,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND,
# either express or implied.
# See the Licence for the specific language governing
# permissions and limitations under the Licence.
#------------------------------------------------------------------------------
"""
A module used to create a thread pool to which tasks can be dispatched.
"""
__author__ = "Davide Salomoni"
__version__ = "2.1"

import sys
import threading
import traceback
import Queue


class ThreadPool(object):
    """
    A class implementing a thread-pool.

    Constructor parameters:
        - size (integer > 0): the number of pooled threads.
        - wait (boolean): what dispatch() does when every pooled thread
          already has a task. If True, the task is queued anyway and runs as
          soon as a thread becomes free; dispatch() itself does not block.
          If False, the task is rejected and dispatch() returns False.
        - join (boolean): wait for the pooled threads to join when stopping the
          pool.
        - errback (method): the method to call in case an error occurs running
          the user-provided function. errback will be called with the traceback
          tuple generated by sys.exc_info(); its return value is used as the
          task result. If errback itself raises, that error is printed to
          stderr and the task result is None.
        - out (queue): a queue where task output will be put. Any object with
          a put() method is accepted.

    Accessible instance variables:
        - size: the current number of pooled threads.
                It may be changed dynamically.
        - busy: the approximate number of dispatched tasks that have not
                finished yet (running plus queued). With wait=True this can
                be larger than size.

    Call the stop() method when done with the pool so that the pooled threads
    can be stopped.
    """

    def __init__(self, size=10, wait=True, join=True, errback=None, out=None):
        assert size > 0, "the number of pooled threads must be positive"
        self._size = size
        self._wait = wait
        self._join = join
        self._errback = errback

        # output queue (optional). Assigned before any worker starts, since
        # workers read it after every task.
        self.outQueue = out

        # a "task" is a tuple (function, args, kwargs); None is the "stop"
        # request that makes one worker exit.
        self._taskQueue = Queue.Queue()
        # number of dispatched tasks that have not completed yet
        self._pendingTasks = 0
        self._threads = []

        # a lock used to protect internal shared data structures like the
        # list of threads and the pending-task counter.
        self._controlLock = threading.Lock()

        self._poolActive = True

        with self._controlLock:
            for _n in range(self._size):
                self._addThread()

    def _addThread(self):
        """
        Internal method to add a thread to the pool.
        NEEDS the control lock because it changes self._threads.
        """
        newThread = threading.Thread(target=self._worker)
        self._threads.append(newThread)
        newThread.start()

    def _removeThread(self):
        """
        Internal method to schedule the removal of a thread from the pool.
        The stop request is queued behind already-dispatched tasks, so those
        still run. Threads remove themselves from self._threads before exiting.
        """
        self._taskQueue.put(None)

    def _scheduleTask(self, task):
        """
        Internal method to schedule a task.
        NEEDS the control lock because it changes self._pendingTasks.
        """
        self._pendingTasks += 1
        self._taskQueue.put(task)

    def _runTask(self, task):
        """
        Internal method to run one task and return its result.

        Never raises. If the task fails, the result is whatever errback
        returns, or None when there is no errback or errback itself fails.
        """
        func, args, kw = task
        try:
            return func(*args, **kw)
        except:
            # bare except on purpose: whatever user code raises must not
            # kill the worker thread. errback is called inside the handler so
            # it still sees the active exception.
            if self._errback is None:
                # no error handler: the failure is deliberately ignored
                return None
            try:
                return self._errback(sys.exc_info())
            except:
                # a failing errback must not kill the worker thread either
                traceback.print_exc()
                return None

    def _worker(self):
        """
        A worker, waiting for tasks to appear in the task queue.
        It runs until it receives a None "stop" task.
        """
        while True:
            # get a task, blocking if there is none
            task = self._taskQueue.get()
            if task is None:
                # exit if it is a "stop" task
                break

            try:
                result = self._runTask(task)
                # Compare with None instead of testing truthiness: a
                # queue-like object defining __len__ is false while empty,
                # and its results would otherwise be dropped.
                if self.outQueue is not None:
                    self.outQueue.put(result)
            except:
                # only reachable if delivering the result failed; report it
                # and keep this worker alive
                traceback.print_exc()
            finally:
                # always account for the task, so that `busy` drops to zero
                with self._controlLock:
                    self._pendingTasks -= 1

        # before going away, remove ourselves from the list of threads
        with self._controlLock:
            self._threads.remove(threading.current_thread())

    def _getRunning(self):
        """
        Internal function to get the approximate number of dispatched tasks
        that have not finished yet (running plus queued).
        """
        with self._controlLock:
            return self._pendingTasks

    def _getSize(self):
        """Internal function to return the pool size."""
        return self._size

    def _setSize(self, newsize):
        """
        Internal function to resize the thread pool.
        It adds or removes threads from the pool as needed. Running threads
        are not interrupted: removal is a stop request queued behind the
        tasks already dispatched.
        """
        assert newsize > 0, "the number of pooled threads must be positive"
        with self._controlLock:
            while newsize > self._size:
                # add threads
                self._addThread()
                self._size += 1

            while newsize < self._size:
                # remove threads. Running threads will run to completion.
                self._removeThread()
                self._size -= 1

    def dispatch(self, func, *args, **kw):
        """
        Execute a task on one of the pooled threads.

        If there is a free thread, the task is scheduled and dispatch()
        returns True.

        If this ThreadPool instance was created with wait=False and there
        are no free threads, dispatch() returns False without scheduling the
        task. It is then the responsibility of the caller to take
        appropriate actions.

        If this ThreadPool instance was created with wait=True and there are
        no free threads, the task is still scheduled and dispatch() returns
        True immediately. The task waits in the internal queue until a
        thread becomes available.
        """
        assert self._poolActive, "the pool is not active"

        task = (func, args, kw)

        with self._controlLock:
            # schedule when there is at least one free thread, or when the
            # pool was asked to queue tasks regardless
            if self._wait or self._pendingTasks < self._size:
                self._scheduleTask(task)
                return True
        return False

    def stop(self):
        """
        Stop the thread pool.

        Tasks already dispatched still run, because the stop requests are
        queued behind them. If the pool was created with join=True, this
        method waits for all pooled threads to exit.
        """
        self._poolActive = False

        with self._controlLock:
            # one stop request per pooled thread. Threads already removed by
            # a size reduction have their own stop requests queued.
            for _n in range(self._size):
                self._removeThread()

        if self._join:
            # wait until all have joined
            with self._controlLock:
                threads = self._threads[:]
            for t in threads:
                t.join()

    # properties to access some internal variables
    size = property(_getSize, _setSize)
    busy = property(_getRunning)

if __name__ == "__main__":
    # tests
    import time
    import random

    poolSize = 50
    numTasks = 72
    outQueue = Queue.Queue()
    DEBUG = False

    def f(s):
        if DEBUG:
            if s > 5.0:
                raise RuntimeError("This a spectacular error!")

        time.sleep(s)
        return ((threading.currentThread().getName(), s))

    def my_handler(info):
        print "**** Exception occurred: %s" % (info,)

    print "ThreadPool test: creating a pool of %s threads\n" % poolSize
    try:
        try:
            p = ThreadPool(size=poolSize, errback=my_handler, out=outQueue)
            print "Dispatching task",
            for t in range(numTasks):
                print t + 1, sys.stdout.flush()
                p.dispatch(f, random.random() * 6)
            print "\n"
    
            while p.busy:
                print "Pool occupancy: %s%%" \
                       % (min(100, int(p.busy / float(p.size) * 100)))
                time.sleep(0.2)
    
        except KeyboardInterrupt:
            print "\nWaiting for running threads to terminate..."
        else:
            print ""

    finally:
        p.stop()  # don't forget to call this before exiting
        print "Results:"
        try:
            i = 1
            while True:
                r = outQueue.get(block=False)
                print "%3d: %s" % (i, r)
                i = i + 1
        except Queue.Empty:
            pass
    
        print "\nTest completed."
