mirror of
https://github.com/python-telegram-bot/python-telegram-bot.git
synced 2025-01-11 04:21:29 +01:00
Merge pull request #177 from tsnoam/master
join() threads for a cleaner stop procedure
This commit is contained in:
commit
04c86813b3
4 changed files with 170 additions and 90 deletions
|
@ -22,7 +22,7 @@
|
|||
import logging
|
||||
from functools import wraps
|
||||
from inspect import getargspec
|
||||
from threading import Thread, BoundedSemaphore, Lock, Event
|
||||
from threading import Thread, BoundedSemaphore, Lock, Event, current_thread
|
||||
from re import match
|
||||
from time import sleep
|
||||
|
||||
|
@ -33,7 +33,8 @@ H = NullHandler()
|
|||
logging.getLogger(__name__).addHandler(H)
|
||||
|
||||
semaphore = None
|
||||
running_async = 0
|
||||
async_threads = set()
|
||||
""":type: set[Thread]"""
|
||||
async_lock = Lock()
|
||||
|
||||
|
||||
|
@ -58,11 +59,10 @@ def run_async(func):
|
|||
"""
|
||||
A wrapper to run a thread in a thread pool
|
||||
"""
|
||||
global running_async, async_lock
|
||||
result = func(*pargs, **kwargs)
|
||||
semaphore.release()
|
||||
with async_lock:
|
||||
running_async -= 1
|
||||
async_threads.remove(current_thread())
|
||||
return result
|
||||
|
||||
@wraps(func)
|
||||
|
@ -70,11 +70,10 @@ def run_async(func):
|
|||
"""
|
||||
A wrapper to run a function in a thread
|
||||
"""
|
||||
global running_async, async_lock
|
||||
thread = Thread(target=pooled, args=pargs, kwargs=kwargs)
|
||||
semaphore.acquire()
|
||||
with async_lock:
|
||||
running_async += 1
|
||||
async_threads.add(thread)
|
||||
thread.start()
|
||||
return thread
|
||||
|
||||
|
|
|
@ -89,6 +89,8 @@ class Updater:
|
|||
self.is_idle = False
|
||||
self.httpd = None
|
||||
self.__lock = Lock()
|
||||
self.__threads = []
|
||||
""":type: list[Thread]"""
|
||||
|
||||
def start_polling(self, poll_interval=0.0, timeout=10, network_delay=2):
|
||||
"""
|
||||
|
@ -120,9 +122,10 @@ class Updater:
|
|||
thr = Thread(target=self._thread_wrapper, name=name,
|
||||
args=(target,) + args, kwargs=kwargs)
|
||||
thr.start()
|
||||
self.__threads.append(thr)
|
||||
|
||||
def _thread_wrapper(self, target, *args, **kwargs):
|
||||
thr_name = current_thread()
|
||||
thr_name = current_thread().name
|
||||
self.logger.debug('{0} - started'.format(thr_name))
|
||||
try:
|
||||
target(*args, **kwargs)
|
||||
|
@ -160,20 +163,10 @@ class Updater:
|
|||
if not self.running:
|
||||
self.running = True
|
||||
|
||||
# Create Thread objects
|
||||
dispatcher_thread = Thread(target=self.dispatcher.start,
|
||||
name="dispatcher")
|
||||
updater_thread = Thread(target=self._start_webhook,
|
||||
name="updater",
|
||||
args=(listen,
|
||||
port,
|
||||
url_path,
|
||||
cert,
|
||||
key))
|
||||
|
||||
# Start threads
|
||||
dispatcher_thread.start()
|
||||
updater_thread.start()
|
||||
# Create & start threads
|
||||
self._init_thread(self.dispatcher.start, "dispatcher"),
|
||||
self._init_thread(self._start_webhook, "updater", listen,
|
||||
port, url_path, cert, key)
|
||||
|
||||
# Return the update queue so the main thread can insert updates
|
||||
return self.update_queue
|
||||
|
@ -221,8 +214,6 @@ class Updater:
|
|||
|
||||
sleep(cur_interval)
|
||||
|
||||
self.logger.info('Updater thread stopped')
|
||||
|
||||
@staticmethod
|
||||
def _increase_poll_interval(current_interval):
|
||||
# increase waiting times on subsequent errors up to 30secs
|
||||
|
@ -266,7 +257,6 @@ class Updater:
|
|||
raise TelegramError('SSL Certificate invalid')
|
||||
|
||||
self.httpd.serve_forever(poll_interval=1)
|
||||
self.logger.info('Updater thread stopped')
|
||||
|
||||
def stop(self):
|
||||
"""
|
||||
|
@ -276,25 +266,49 @@ class Updater:
|
|||
self.job_queue.stop()
|
||||
with self.__lock:
|
||||
if self.running:
|
||||
self.running = False
|
||||
self.logger.info('Stopping Updater and Dispatcher...')
|
||||
self.logger.debug('This might take a long time if you set a '
|
||||
'high value as polling timeout.')
|
||||
|
||||
if self.httpd:
|
||||
self.logger.info(
|
||||
'Waiting for current webhook connection to be '
|
||||
'closed... Send a Telegram message to the bot to exit '
|
||||
'immediately.')
|
||||
self.httpd.shutdown()
|
||||
self.httpd = None
|
||||
self.running = False
|
||||
|
||||
self.logger.debug("Requesting Dispatcher to stop...")
|
||||
self.dispatcher.stop()
|
||||
while dispatcher.running_async > 0:
|
||||
sleep(1)
|
||||
self._stop_httpd()
|
||||
self._stop_dispatcher()
|
||||
self._join_threads()
|
||||
# async threads must be join()ed only after the dispatcher
|
||||
# thread was joined, otherwise we can still have new async
|
||||
# threads dispatched
|
||||
self._join_async_threads()
|
||||
|
||||
self.logger.debug("Dispatcher stopped.")
|
||||
def _stop_httpd(self):
|
||||
if self.httpd:
|
||||
self.logger.info(
|
||||
'Waiting for current webhook connection to be '
|
||||
'closed... Send a Telegram message to the bot to exit '
|
||||
'immediately.')
|
||||
self.httpd.shutdown()
|
||||
self.httpd = None
|
||||
|
||||
def _stop_dispatcher(self):
|
||||
self.logger.debug("Requesting Dispatcher to stop...")
|
||||
self.dispatcher.stop()
|
||||
|
||||
def _join_async_threads(self):
|
||||
with dispatcher.async_lock:
|
||||
threads = list(dispatcher.async_threads)
|
||||
total = len(threads)
|
||||
for i, thr in enumerate(threads):
|
||||
self.logger.info(
|
||||
'Waiting for async thread {0}/{1} to end'.format(i, total))
|
||||
thr.join()
|
||||
self.logger.debug(
|
||||
'async thread {0}/{1} has ended'.format(i, total))
|
||||
|
||||
def _join_threads(self):
|
||||
for thr in self.__threads:
|
||||
self.logger.info(
|
||||
'Waiting for {0} thread to end'.format(thr.name))
|
||||
thr.join()
|
||||
self.logger.debug('{0} thread has ended'.format(thr.name))
|
||||
self.__threads = []
|
||||
|
||||
def signal_handler(self, signum, frame):
|
||||
self.is_idle = False
|
||||
|
|
|
@ -1,7 +1,7 @@
|
|||
import logging
|
||||
|
||||
from telegram import Update, NullHandler
|
||||
from future.utils import bytes_to_native_str as n
|
||||
from future.utils import bytes_to_native_str
|
||||
from threading import Lock
|
||||
import json
|
||||
try:
|
||||
|
@ -14,6 +14,13 @@ H = NullHandler()
|
|||
logging.getLogger(__name__).addHandler(H)
|
||||
|
||||
|
||||
class _InvalidPost(Exception):
|
||||
|
||||
def __init__(self, http_code):
|
||||
self.http_code = http_code
|
||||
super(_InvalidPost, self).__init__()
|
||||
|
||||
|
||||
class WebhookServer(BaseHTTPServer.HTTPServer, object):
|
||||
def __init__(self, server_address, RequestHandlerClass, update_queue,
|
||||
webhook_path):
|
||||
|
@ -63,12 +70,15 @@ class WebhookHandler(BaseHTTPServer.BaseHTTPRequestHandler, object):
|
|||
|
||||
def do_POST(self):
|
||||
self.logger.debug("Webhook triggered")
|
||||
if self.path == self.server.webhook_path and \
|
||||
'content-type' in self.headers and \
|
||||
'content-length' in self.headers and \
|
||||
self.headers['content-type'] == 'application/json':
|
||||
json_string = \
|
||||
n(self.rfile.read(int(self.headers['content-length'])))
|
||||
try:
|
||||
self._validate_post()
|
||||
clen = self._get_content_len()
|
||||
except _InvalidPost as e:
|
||||
self.send_error(e.http_code)
|
||||
self.end_headers()
|
||||
else:
|
||||
buf = self.rfile.read(clen)
|
||||
json_string = bytes_to_native_str(buf)
|
||||
|
||||
self.send_response(200)
|
||||
self.end_headers()
|
||||
|
@ -80,6 +90,20 @@ class WebhookHandler(BaseHTTPServer.BaseHTTPRequestHandler, object):
|
|||
update.update_id)
|
||||
self.server.update_queue.put(update)
|
||||
|
||||
else:
|
||||
self.send_error(403)
|
||||
self.end_headers()
|
||||
def _validate_post(self):
|
||||
if not (self.path == self.server.webhook_path and
|
||||
'content-type' in self.headers and
|
||||
self.headers['content-type'] == 'application/json'):
|
||||
raise _InvalidPost(403)
|
||||
|
||||
def _get_content_len(self):
|
||||
clen = self.headers.get('content-length')
|
||||
if clen is None:
|
||||
raise _InvalidPost(411)
|
||||
try:
|
||||
clen = int(clen)
|
||||
except ValueError:
|
||||
raise _InvalidPost(403)
|
||||
if clen < 0:
|
||||
raise _InvalidPost(403)
|
||||
return clen
|
||||
|
|
|
@ -30,6 +30,7 @@ import signal
|
|||
from random import randrange
|
||||
from time import sleep
|
||||
from datetime import datetime
|
||||
from future.builtins import bytes
|
||||
|
||||
if sys.version_info[0:2] == (2, 6):
|
||||
import unittest2 as unittest
|
||||
|
@ -37,9 +38,12 @@ else:
|
|||
import unittest
|
||||
|
||||
try:
|
||||
from urllib2 import urlopen, Request
|
||||
# python2
|
||||
from urllib2 import urlopen, Request, HTTPError
|
||||
except ImportError:
|
||||
# python3
|
||||
from urllib.request import Request, urlopen
|
||||
from urllib.error import HTTPError
|
||||
|
||||
sys.path.append('.')
|
||||
|
||||
|
@ -399,9 +403,9 @@ class UpdaterTest(BaseTest, unittest.TestCase):
|
|||
d.addTelegramMessageHandler(
|
||||
self.telegramHandlerTest)
|
||||
|
||||
# Select random port for travis
|
||||
port = randrange(1024, 49152)
|
||||
self.updater.start_webhook('127.0.0.1', port,
|
||||
ip = '127.0.0.1'
|
||||
port = randrange(1024, 49152) # Select random port for travis
|
||||
self.updater.start_webhook(ip, port,
|
||||
url_path='TOKEN',
|
||||
cert='./tests/test_updater.py',
|
||||
key='./tests/test_updater.py')
|
||||
|
@ -417,34 +421,19 @@ class UpdaterTest(BaseTest, unittest.TestCase):
|
|||
update = Update(1)
|
||||
update.message = message
|
||||
|
||||
try:
|
||||
payload = bytes(update.to_json(), encoding='utf-8')
|
||||
except TypeError:
|
||||
payload = bytes(update.to_json())
|
||||
|
||||
header = {
|
||||
'content-type': 'application/json',
|
||||
'content-length': str(len(payload))
|
||||
}
|
||||
|
||||
r = Request('http://127.0.0.1:%d/TOKEN' % port,
|
||||
data=payload,
|
||||
headers=header)
|
||||
|
||||
urlopen(r)
|
||||
self._send_webhook_msg(ip, port, update.to_json(), 'TOKEN')
|
||||
|
||||
sleep(1)
|
||||
self.assertEqual(self.received_message, 'Webhook Test')
|
||||
|
||||
print("Test other webhook server functionalities...")
|
||||
request = Request('http://localhost:%d/webookhandler.py' % port)
|
||||
response = urlopen(request)
|
||||
response = self._send_webhook_msg(ip, port, None, 'webookhandler.py')
|
||||
self.assertEqual(b'', response.read())
|
||||
self.assertEqual(200, response.code)
|
||||
|
||||
request.get_method = lambda: 'HEAD'
|
||||
response = self._send_webhook_msg(ip, port, None, 'webookhandler.py',
|
||||
get_method=lambda: 'HEAD')
|
||||
|
||||
response = urlopen(request)
|
||||
self.assertEqual(b'', response.read())
|
||||
self.assertEqual(200, response.code)
|
||||
|
||||
|
@ -460,9 +449,9 @@ class UpdaterTest(BaseTest, unittest.TestCase):
|
|||
d.addTelegramMessageHandler(
|
||||
self.telegramHandlerTest)
|
||||
|
||||
# Select random port for travis
|
||||
port = randrange(1024, 49152)
|
||||
self.updater.start_webhook('127.0.0.1', port)
|
||||
ip = '127.0.0.1'
|
||||
port = randrange(1024, 49152) # Select random port for travis
|
||||
self.updater.start_webhook(ip, port)
|
||||
sleep(0.5)
|
||||
|
||||
# Now, we send an update to the server via urlopen
|
||||
|
@ -473,24 +462,78 @@ class UpdaterTest(BaseTest, unittest.TestCase):
|
|||
update = Update(1)
|
||||
update.message = message
|
||||
|
||||
try:
|
||||
payload = bytes(update.to_json(), encoding='utf-8')
|
||||
except TypeError:
|
||||
payload = bytes(update.to_json())
|
||||
|
||||
header = {
|
||||
'content-type': 'application/json',
|
||||
'content-length': str(len(payload))
|
||||
}
|
||||
|
||||
r = Request('http://127.0.0.1:%d/' % port,
|
||||
data=payload,
|
||||
headers=header)
|
||||
|
||||
urlopen(r)
|
||||
self._send_webhook_msg(ip, port, update.to_json())
|
||||
sleep(1)
|
||||
self.assertEqual(self.received_message, 'Webhook Test 2')
|
||||
|
||||
def test_webhook_invalid_posts(self):
|
||||
self._setup_updater('', messages=0)
|
||||
|
||||
ip = '127.0.0.1'
|
||||
port = randrange(1024, 49152) # select random port for travis
|
||||
thr = Thread(target=self.updater._start_webhook,
|
||||
args=(ip, port, '', None, None))
|
||||
thr.start()
|
||||
|
||||
sleep(0.5)
|
||||
|
||||
try:
|
||||
with self.assertRaises(HTTPError) as ctx:
|
||||
self._send_webhook_msg(ip, port,
|
||||
'<root><bla>data</bla></root>',
|
||||
content_type='application/xml')
|
||||
self.assertEqual(ctx.exception.code, 403)
|
||||
|
||||
with self.assertRaises(HTTPError) as ctx:
|
||||
self._send_webhook_msg(ip, port, 'dummy-payload',
|
||||
content_len=-2)
|
||||
self.assertEqual(ctx.exception.code, 403)
|
||||
|
||||
# TODO: prevent urllib or the underlying from adding content-length
|
||||
# with self.assertRaises(HTTPError) as ctx:
|
||||
# self._send_webhook_msg(ip, port, 'dummy-payload',
|
||||
# content_len=None)
|
||||
# self.assertEqual(ctx.exception.code, 411)
|
||||
|
||||
with self.assertRaises(HTTPError) as ctx:
|
||||
self._send_webhook_msg(ip, port, 'dummy-payload',
|
||||
content_len='not-a-number')
|
||||
self.assertEqual(ctx.exception.code, 403)
|
||||
|
||||
finally:
|
||||
self.updater._stop_httpd()
|
||||
thr.join()
|
||||
|
||||
def _send_webhook_msg(self, ip, port, payload_str, url_path='',
|
||||
content_len=-1, content_type='application/json',
|
||||
get_method=None):
|
||||
headers = {
|
||||
'content-type': content_type,
|
||||
}
|
||||
|
||||
if not payload_str:
|
||||
content_len = None
|
||||
payload = None
|
||||
else:
|
||||
payload = bytes(payload_str, encoding='utf-8')
|
||||
|
||||
if content_len == -1:
|
||||
content_len = len(payload)
|
||||
|
||||
if content_len is not None:
|
||||
headers['content-length'] = str(content_len)
|
||||
|
||||
url = 'http://{ip}:{port}/{path}'.format(ip=ip, port=port,
|
||||
path=url_path)
|
||||
|
||||
req = Request(url, data=payload, headers=headers)
|
||||
|
||||
|
||||
if get_method is not None:
|
||||
req.get_method = get_method
|
||||
|
||||
return urlopen(req)
|
||||
|
||||
def signalsender(self):
|
||||
sleep(0.5)
|
||||
os.kill(os.getpid(), signal.SIGTERM)
|
||||
|
|
Loading…
Reference in a new issue