diff --git a/voussoirkit/threadpool.py b/voussoirkit/threadpool.py
index c160cba..9cf1224 100644
--- a/voussoirkit/threadpool.py
+++ b/voussoirkit/threadpool.py
@@ -4,28 +4,49 @@ of threadpool in use:
 
 1. Powering a single api scraping generator with many threads:
 
-pool = threadpool.ThreadPool(thread_count, paused=True)
-job_gen = ({'function': api.get_item, 'kwargs': {'id': i}} for i in range(lower, upper+1))
-pool.add_generator(job_gen)
-for job in pool.result_generator():
-    if job.exception:
-        raise job.exception
-    if job.value is not None:
-        yield job.value
+>>> pool = threadpool.ThreadPool(thread_count, paused=True)
+>>> job_gen = ({'function': api.get_item, 'kwargs': {'id': i}} for i in range(lower, upper+1))
+>>> pool.add_generator(job_gen)
+>>> for job in pool.result_generator():
+>>>     if job.exception:
+>>>         raise job.exception
+>>>     if job.value is not None:
+>>>         yield job.value
+
+2. Git-fetching a bunch of repositories with no error handling:
+
+>>> def git_fetch(d):
+>>>     command = [GIT, '-C', d, 'fetch', '--all']
+>>>     print(command)
+>>>     subprocess.check_output(command, stderr=subprocess.STDOUT)
+>>>
+>>> def callback(job):
+>>>     if job.exception:
+>>>         print(f'{job.name} caused {job.exception}.')
+>>>
+>>> pool = threadpool.ThreadPool(thread_count, paused=False)
+>>> kwargss = [{'function': git_fetch, 'args': [d], 'name': d, 'callback': callback} for d in dirs]
+>>> pool.add_many(kwargss)
+>>> pool.join()
 '''
 import collections
+import logging
 import queue
 import threading
+import traceback
 
 from voussoirkit import lazychain
 from voussoirkit import sentinel
 
+log = logging.getLogger('threadpool')
+
 PENDING = sentinel.Sentinel('PENDING')
 RUNNING = sentinel.Sentinel('RUNNING')
 FINISHED = sentinel.Sentinel('FINISHED')
 RAISED = sentinel.Sentinel('RAISED')
+NO_MORE_JOBS = sentinel.Sentinel('NO_MORE_JOBS')
+
 NO_RETURN = sentinel.Sentinel('NO_RETURN', truthyness=False)
 NO_EXCEPTION = sentinel.Sentinel('NO_EXCEPTION', truthyness=False)
 
@@ -35,6 +56,61 @@ class ThreadPoolException(Exception):
     pass
 
 class PoolClosed(ThreadPoolException):
     pass
 
+class PooledThread:
+    def __init__(self, pool):
+        self.pool = pool
+        self.thread = threading.Thread(target=self.start)
+        self.thread.daemon = True
+        self.thread.start()
+
+    def __repr__(self):
+        return f'PooledThread {self.thread}'
+
+    def _run_once(self):
+        # Any exceptions caused by the job's primary function are already
+        # wrapped safely, but there are two other sources of potential
+        # exceptions:
+        # 1. A generator given to add_generator encounters an exception
+        #    while generating the kwargs, causing get_next_job to raise.
+        # 2. The callback function given to the Job raises.
+        # It's hard to say what the correct course of action is, but I
+        # realllly don't want them taking down the whole worker thread.
+        try:
+            job = self.pool.get_next_job()
+        except BaseException:
+            traceback.print_exc()
+            return
+
+        if job is NO_MORE_JOBS:
+            return NO_MORE_JOBS
+
+        log.debug('%s is running job %s.', self, job)
+        self.pool._running_count += 1
+        try:
+            job.run()
+        except BaseException:
+            traceback.print_exc()
+        self.pool._running_count -= 1
+
+    def join(self):
+        log.debug('%s is joining.', self)
+        self.thread.join()
+
+    def start(self):
+        while True:
+            # Let's wait for jobs_available first and unpaused second.
+            # If the time between the two waits is very long, the worst thing
+            # that can happen is there are no more jobs by the time we get
+            # there, and the loop comes around again. On the other hand, if
+            # unpaused.wait is first and the time until available.wait is very
+            # long, we might wind up running a job despite the user pausing
+            # the pool in the interim.
+            self.pool._jobs_available.wait()
+            self.pool._unpaused_event.wait()
+            status = self._run_once()
+            if status is NO_MORE_JOBS and self.pool.closed:
+                break
+
 class ThreadPool:
     '''
     The ThreadPool is used to perform large numbers of tasks using a pool of
@@ -70,34 +146,37 @@ class ThreadPool:
         if size < 1:
             raise ValueError(f'size must be >= 1, not {size}.')
-        self.max_size = size
-        self.paused = paused
+        self._unpaused_event = threading.Event()
+        if not paused:
+            self._unpaused_event.set()
+
+        self._jobs_available = threading.Event()
 
         self._closed = False
         self._running_count = 0
         self._result_queue = None
         self._pending_jobs = lazychain.LazyChain()
         self._job_manager_lock = threading.Lock()
-        self._all_done_event = threading.Event()
-        self._all_done_event.set()
 
-    def _job_finished(self):
-        '''
-        When a job finishes, it will call here so that a new job can be started.
-        '''
-        self._running_count -= 1
-
-        if not self.paused:
-            self.start()
+        self._size = size
+        self._threads = [PooledThread(pool=self) for x in range(size)]
 
     @property
     def closed(self):
-        return self.closed
+        return self._closed
+
+    @property
+    def paused(self):
+        return not self._unpaused_event.is_set()
 
     @property
     def running_count(self):
         return self._running_count
 
+    @property
+    def size(self):
+        return self._size
+
     def assert_not_closed(self):
         '''
         If the pool is closed (because you called `join`), raise PoolClosed.
         '''
@@ -122,9 +201,7 @@
             kwargs=kwargs,
         )
         self._pending_jobs.append(job)
-
-        if not self.paused:
-            self.start()
+        self._jobs_available.set()
 
         return job
 
@@ -142,9 +219,7 @@
 
         these_jobs = (Job(pool=self, **kwargs) for kwargs in kwargs_gen)
         self._pending_jobs.extend(these_jobs)
-
-        if not self.paused:
-            self.start()
+        self._jobs_available.set()
 
     def add_many(self, kwargss):
         '''
@@ -161,24 +236,48 @@
         '''
        self.assert_not_closed()
 
+        kwargss = list(kwargss)
+        if not kwargss:
+            raise ValueError('kwargss must not be empty.')
+
         these_jobs = [Job(pool=self, **kwargs) for kwargs in kwargss]
         self._pending_jobs.extend(these_jobs)
-
-        if not self.paused:
-            self.start()
+        self._jobs_available.set()
 
         return these_jobs
 
+    def get_next_job(self):
+        with self._job_manager_lock:
+            try:
+                job = next(self._pending_jobs)
+            except StopIteration:
+                # If we ARE closed, we want to keep the flag set so that all
+                # the threads can keep waking up and seeing no more jobs.
+                if not self.closed:
+                    self._jobs_available.clear()
+                return NO_MORE_JOBS
+
+            if self._result_queue is not None:
+                # This will block if the queue is full.
+                self._result_queue.put(job)
+
+            return job
+
     def join(self):
         '''
         Permanently close the pool, preventing any new jobs from being added,
         and block until all jobs are complete.
         '''
+        log.debug('%s is joining.', self)
         self._closed = True
+        # The threads which are currently paused at _jobs_available.wait() need
+        # to be woken up so they can realize the pool is closed and break.
+        self._jobs_available.set()
         self.start()
-        self._all_done_event.wait()
+        for thread in self._threads:
+            thread.join()
 
-    def result_generator(self):
+    def result_generator(self, *, buffer_size=None):
         '''
         This generator will start the job pool, then yield finished/raised
         Job objects in the order they were added. Note that a slow job will
@@ -197,52 +296,60 @@
         When there are no more outstanding jobs, the generator will stop
         iteration and return. If the pool was paused before generating, it
-        will be paused again.
+        will be paused again. This prevents subsequently added jobs from being
+        lost as described.
+
+        buffer_size:
+            The size of the buffer which holds jobs before they are yielded.
+            If you expect your production to outpace your consumption, you may
+            wish to set this value to prevent high memory usage. When the buffer
+            is full, new jobs will be blocked from starting.
         '''
         if self._result_queue is not None:
             raise TypeError('The result generator is already open.')
-        self._result_queue = queue.Queue()
+
+        self._result_queue = queue.Queue(maxsize=buffer_size or 0)
         was_paused = self.paused
+
         self.start()
-        while (not self._all_done_event.is_set()) or (not self._result_queue.empty()):
+        # Considerations for the while loop condition:
+        # Why `jobs_available.is_set`: Consider a group of slow-running threads
+        # which are launched and whose jobs are added to the result_queue. The
+        # caller of this generator consumes all of them before the threads
+        # finish and start a new job. So, we need to watch jobs_available.is_set
+        # to know that even though the result_queue is currently empty, we can
+        # expect more to be ready soon and shouldn't break yet.
+        # Why `not result_queue.empty`: Consider a group of fast-running
+        # threads which exhaust all available jobs, clearing jobs_available
+        # while finished jobs still sit in the result_queue waiting to be
+        # yielded.
+        # Why not `not closed`: After the pool is closed, the outstanding jobs
+        # still need to finish. Closing does not imply pausing or cancelling
+        # jobs.
+        while self._jobs_available.is_set() or not self._result_queue.empty():
             job = self._result_queue.get()
             job.join()
             yield job
             self._result_queue.task_done()
 
         self._result_queue = None
+
         if was_paused:
-            self.paused = True
+            self.pause()
+
+    def pause(self):
+        self._unpaused_event.clear()
 
     def start(self):
-        self.paused = False
-
-        with self._job_manager_lock:
-            available = self.max_size - self._running_count
-
-            no_more_jobs = False
-            for x in range(available):
-                try:
-                    job = next(self._pending_jobs)
-                except StopIteration:
-                    no_more_jobs = True
-                    break
-
-                self._all_done_event.clear()
-                job.start()
-                self._running_count += 1
-                if self._result_queue is not None:
-                    self._result_queue.put(job)
-
-        if self._running_count == 0 and no_more_jobs:
-            self._all_done_event.set()
+        self._unpaused_event.set()
 
 class Job:
     '''
     Each job contains one function that it will call when it is started.
 
-    If the function completes successfully you will find the return value in
-    `job.value`. If it raises an exception, you'll find it in `job.exception`,
-    although the thread itself will not raise.
+    If the function completes successfully (status is threadpool.FINISHED) you
+    will find the return value in `job.value`. If it raises an exception
+    (status is threadpool.RAISED), you'll find it in `job.exception`, although
+    the thread itself will not raise.
 
     All job threads are daemons and will not prevent the main thread from
     terminating. Call `job.join()` or `pool.join()` in the main thread to
@@ -272,14 +379,8 @@
         self.kwargs = kwargs
         self.value = NO_RETURN
         self.exception = NO_EXCEPTION
-        self._thread = None
 
-        # _joinme_lock works because it is possible for a single thread to block
-        # itself by calling `lock.acquire()` twice. The first call is here,
-        # and the second call is in `join` so that join will block until the
-        # lock is released by the job's finishing phase.
-        self._joinme_lock = threading.Lock()
-        self._joinme_lock.acquire()
+        self._done_event = threading.Event()
 
     def __repr__(self):
         if self.name:
@@ -287,28 +388,22 @@
         else:
             return f'<{self.status.name} Job on {self.function}>'
 
-    def _run(self):
+    def run(self):
+        self.status = RUNNING
         try:
             self.value = self.function(*self.args, **self.kwargs)
             self.status = FINISHED
         except BaseException as exc:
             self.exception = exc
             self.status = RAISED
-        self._thread = None
-        self._joinme_lock.release()
-        self.pool._job_finished()
+
         if self.callback is not None:
             self.callback(self)
 
+        self._done_event.set()
+
     def join(self):
         '''
         Block until this job runs and completes.
        '''
-        self._joinme_lock.acquire()
-        self._joinme_lock.release()
-
-    def start(self):
-        self.status = RUNNING
-        self._thread = threading.Thread(target=self._run)
-        self._thread.daemon = True
-        self._thread.start()
+        self._done_event.wait()
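
A sketch of how the new buffer_size parameter might be used, in the same style
as the module docstring's examples. The names thread_count, urls, and
slow_scrape are hypothetical stand-ins, not part of this diff:

>>> pool = threadpool.ThreadPool(thread_count, paused=True)
>>> pool.add_generator({'function': slow_scrape, 'args': [url]} for url in urls)
>>> # result_generator passes buffer_size to queue.Queue(maxsize=...), so
>>> # get_next_job blocks at _result_queue.put once 100 jobs are waiting to be
>>> # yielded, keeping memory bounded when production outpaces consumption.
>>> for job in pool.result_generator(buffer_size=100):
>>>     if job.status is threadpool.RAISED:
>>>         raise job.exception
>>>     print(job.value)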