join() threads instead of guessing if they're running

- new book keeping of dispatcher's async threads so they can be joined
   when stopping
 - updater, webhook & dispatcher threads are now kept on
   Updater.__threads so they can be joined at the end

refs #175
This commit is contained in:
Noam Meltzer 2016-02-09 23:08:27 +02:00
parent fd7baa2236
commit d415a60ebf
2 changed files with 47 additions and 24 deletions

View file

@ -22,7 +22,7 @@
import logging
from functools import wraps
from inspect import getargspec
from threading import Thread, BoundedSemaphore, Lock, Event
from threading import Thread, BoundedSemaphore, Lock, Event, current_thread
from re import match
from time import sleep
@ -33,7 +33,8 @@ H = NullHandler()
logging.getLogger(__name__).addHandler(H)
semaphore = None
running_async = 0
async_threads = set()
""":type: set[Thread]"""
async_lock = Lock()
@ -58,11 +59,10 @@ def run_async(func):
"""
A wrapper to run a thread in a thread pool
"""
global running_async, async_lock
result = func(*pargs, **kwargs)
semaphore.release()
with async_lock:
running_async -= 1
async_threads.remove(current_thread())
return result
@wraps(func)
@ -70,11 +70,10 @@ def run_async(func):
"""
A wrapper to run a function in a thread
"""
global running_async, async_lock
thread = Thread(target=pooled, args=pargs, kwargs=kwargs)
semaphore.acquire()
with async_lock:
running_async += 1
async_threads.add(thread)
thread.start()
return thread

View file

@ -89,6 +89,8 @@ class Updater:
self.is_idle = False
self.httpd = None
self.__lock = Lock()
self.__threads = []
""":type: list[Thread]"""
def start_polling(self, poll_interval=0.0, timeout=10, network_delay=2):
"""
@ -120,6 +122,7 @@ class Updater:
thr = Thread(target=self._thread_wrapper, name=name,
args=(target,) + args, kwargs=kwargs)
thr.start()
self.__threads.append(thr)
def _thread_wrapper(self, target, *args, **kwargs):
thr_name = current_thread().name
@ -211,8 +214,6 @@ class Updater:
sleep(cur_interval)
self.logger.info('Updater thread stopped')
@staticmethod
def _increase_poll_interval(current_interval):
# increase waiting times on subsequent errors up to 30secs
@ -256,7 +257,6 @@ class Updater:
raise TelegramError('SSL Certificate invalid')
self.httpd.serve_forever(poll_interval=1)
self.logger.info('Updater thread stopped')
def stop(self):
"""
@ -266,25 +266,49 @@ class Updater:
self.job_queue.stop()
with self.__lock:
if self.running:
self.running = False
self.logger.info('Stopping Updater and Dispatcher...')
self.logger.debug('This might take a long time if you set a '
'high value as polling timeout.')
if self.httpd:
self.logger.info(
'Waiting for current webhook connection to be '
'closed... Send a Telegram message to the bot to exit '
'immediately.')
self.httpd.shutdown()
self.httpd = None
self.running = False
self.logger.debug("Requesting Dispatcher to stop...")
self.dispatcher.stop()
while dispatcher.running_async > 0:
sleep(1)
self._stop_httpd()
self._stop_dispatcher()
self._join_threads()
# async threads must be join()ed only after the dispatcher
# thread was joined, otherwise we can still have new async
# threads dispatched
self._join_async_threads()
self.logger.debug("Dispatcher stopped.")
def _stop_httpd(self):
if self.httpd:
self.logger.info(
'Waiting for current webhook connection to be '
'closed... Send a Telegram message to the bot to exit '
'immediately.')
self.httpd.shutdown()
self.httpd = None
def _stop_dispatcher(self):
self.logger.debug("Requesting Dispatcher to stop...")
self.dispatcher.stop()
def _join_async_threads(self):
with dispatcher.async_lock:
threads = list(dispatcher.async_threads)
total = len(threads)
for i, thr in enumerate(threads):
self.logger.info(
'Waiting for async thread {0}/{1} to end'.format(i, total))
thr.join()
self.logger.debug(
'async thread {0}/{1} has ended'.format(i, total))
def _join_threads(self):
for thr in self.__threads:
self.logger.info(
'Waiting for {0} thread to end'.format(thr.name))
thr.join()
self.logger.debug('{0} thread has ended'.format(thr.name))
self.__threads = []
def signal_handler(self, signum, frame):
self.is_idle = False