mirror of
https://github.com/python-telegram-bot/python-telegram-bot.git
synced 2025-03-30 03:09:21 +02:00
updater/dispatcher: on exception stop all threads
This commit is contained in:
parent
7ebbc60694
commit
4a5001668d
2 changed files with 39 additions and 16 deletions
telegram
|
@ -50,6 +50,9 @@ def run_async(func):
|
|||
function:
|
||||
"""
|
||||
|
||||
# TODO: handle exception in async threads
|
||||
# set a threading.Event to notify caller thread
|
||||
|
||||
@wraps(func)
|
||||
def pooled(*pargs, **kwargs):
|
||||
"""
|
||||
|
@ -132,7 +135,7 @@ class Dispatcher:
|
|||
update_queue (telegram.UpdateQueue): The synchronized queue that will
|
||||
contain the updates.
|
||||
"""
|
||||
def __init__(self, bot, update_queue, workers=4):
|
||||
def __init__(self, bot, update_queue, workers=4, exception_event=None):
|
||||
self.bot = bot
|
||||
self.update_queue = update_queue
|
||||
self.telegram_message_handlers = []
|
||||
|
@ -147,6 +150,7 @@ class Dispatcher:
|
|||
self.logger = logging.getLogger(__name__)
|
||||
self.running = False
|
||||
self.__stop_event = Event()
|
||||
self.__exception_event = exception_event or Event()
|
||||
|
||||
global semaphore
|
||||
if not semaphore:
|
||||
|
@ -164,6 +168,11 @@ class Dispatcher:
|
|||
self.logger.warning('already running')
|
||||
return
|
||||
|
||||
if self.__exception_event.is_set():
|
||||
msg = 'reusing dispatcher after exception event is forbidden'
|
||||
self.logger.error(msg)
|
||||
raise TelegramError(msg)
|
||||
|
||||
self.running = True
|
||||
self.logger.info('Dispatcher started')
|
||||
|
||||
|
@ -173,6 +182,11 @@ class Dispatcher:
|
|||
update, context = self.update_queue.get(True, 1, True)
|
||||
except Empty:
|
||||
if self.__stop_event.is_set():
|
||||
self.logger.info('orderly stopping')
|
||||
break
|
||||
elif self.__stop_event.is_set():
|
||||
self.logger.critical(
|
||||
'stopping due to exception in another thread')
|
||||
break
|
||||
continue
|
||||
|
||||
|
|
|
@ -24,7 +24,7 @@ Telegram bots intuitive."""
|
|||
import logging
|
||||
import os
|
||||
import ssl
|
||||
from threading import Thread, Lock
|
||||
from threading import Thread, Lock, current_thread, Event
|
||||
from time import sleep
|
||||
import subprocess
|
||||
from signal import signal, SIGINT, SIGTERM, SIGABRT
|
||||
|
@ -80,8 +80,9 @@ class Updater:
|
|||
self.bot = Bot(token, base_url)
|
||||
self.update_queue = UpdateQueue()
|
||||
self.job_queue = JobQueue(self.bot, job_queue_tick_interval)
|
||||
self.dispatcher = Dispatcher(self.bot, self.update_queue,
|
||||
workers=workers)
|
||||
self.__exception_event = Event()
|
||||
self.dispatcher = Dispatcher(self.bot, self.update_queue, workers,
|
||||
self.__exception_event)
|
||||
self.last_update_id = 0
|
||||
self.logger = logging.getLogger(__name__)
|
||||
self.running = False
|
||||
|
@ -107,22 +108,30 @@ class Updater:
|
|||
if not self.running:
|
||||
self.running = True
|
||||
|
||||
# Create Thread objects
|
||||
dispatcher_thread = Thread(target=self.dispatcher.start,
|
||||
name="dispatcher")
|
||||
updater_thread = Thread(target=self._start_polling,
|
||||
name="updater",
|
||||
args=(poll_interval,
|
||||
timeout,
|
||||
network_delay))
|
||||
|
||||
# Start threads
|
||||
dispatcher_thread.start()
|
||||
updater_thread.start()
|
||||
# Create & start threads
|
||||
self._init_thread(self.dispatcher.start, "dispatcher")
|
||||
self._init_thread(self._start_polling, "updater",
|
||||
poll_interval, timeout, network_delay)
|
||||
|
||||
# Return the update queue so the main thread can insert updates
|
||||
return self.update_queue
|
||||
|
||||
def _init_thread(self, target, name, *args, **kwargs):
|
||||
thr = Thread(target=self._thread_wrapper, name=name,
|
||||
args=(target,) + args, kwargs=kwargs)
|
||||
thr.start()
|
||||
|
||||
def _thread_wrapper(self, target, *args, **kwargs):
|
||||
thr_name = current_thread()
|
||||
self.logger.debug('{0} - started'.format(thr_name))
|
||||
try:
|
||||
target(*args, **kwargs)
|
||||
except Exception:
|
||||
self.__exception_event.set()
|
||||
self.logger.exception('unhandled exception')
|
||||
raise
|
||||
self.logger.debug('{0} - ended'.format(thr_name))
|
||||
|
||||
def start_webhook(self,
|
||||
listen='127.0.0.1',
|
||||
port=80,
|
||||
|
|
Loading…
Add table
Reference in a new issue