Merge pull request #177 from tsnoam/master

join() threads for a cleaner stop procedure
This commit is contained in:
Noam Meltzer 2016-02-19 19:10:35 +02:00
commit 04c86813b3
4 changed files with 170 additions and 90 deletions

View file

@ -22,7 +22,7 @@
import logging
from functools import wraps
from inspect import getargspec
from threading import Thread, BoundedSemaphore, Lock, Event
from threading import Thread, BoundedSemaphore, Lock, Event, current_thread
from re import match
from time import sleep
@ -33,7 +33,8 @@ H = NullHandler()
logging.getLogger(__name__).addHandler(H)
semaphore = None
running_async = 0
async_threads = set()
""":type: set[Thread]"""
async_lock = Lock()
@ -58,11 +59,10 @@ def run_async(func):
"""
A wrapper to run a thread in a thread pool
"""
global running_async, async_lock
result = func(*pargs, **kwargs)
semaphore.release()
with async_lock:
running_async -= 1
async_threads.remove(current_thread())
return result
@wraps(func)
@ -70,11 +70,10 @@ def run_async(func):
"""
A wrapper to run a function in a thread
"""
global running_async, async_lock
thread = Thread(target=pooled, args=pargs, kwargs=kwargs)
semaphore.acquire()
with async_lock:
running_async += 1
async_threads.add(thread)
thread.start()
return thread

View file

@ -89,6 +89,8 @@ class Updater:
self.is_idle = False
self.httpd = None
self.__lock = Lock()
self.__threads = []
""":type: list[Thread]"""
def start_polling(self, poll_interval=0.0, timeout=10, network_delay=2):
"""
@ -120,9 +122,10 @@ class Updater:
thr = Thread(target=self._thread_wrapper, name=name,
args=(target,) + args, kwargs=kwargs)
thr.start()
self.__threads.append(thr)
def _thread_wrapper(self, target, *args, **kwargs):
thr_name = current_thread()
thr_name = current_thread().name
self.logger.debug('{0} - started'.format(thr_name))
try:
target(*args, **kwargs)
@ -160,20 +163,10 @@ class Updater:
if not self.running:
self.running = True
# Create Thread objects
dispatcher_thread = Thread(target=self.dispatcher.start,
name="dispatcher")
updater_thread = Thread(target=self._start_webhook,
name="updater",
args=(listen,
port,
url_path,
cert,
key))
# Start threads
dispatcher_thread.start()
updater_thread.start()
# Create & start threads
self._init_thread(self.dispatcher.start, "dispatcher"),
self._init_thread(self._start_webhook, "updater", listen,
port, url_path, cert, key)
# Return the update queue so the main thread can insert updates
return self.update_queue
@ -221,8 +214,6 @@ class Updater:
sleep(cur_interval)
self.logger.info('Updater thread stopped')
@staticmethod
def _increase_poll_interval(current_interval):
# increase waiting times on subsequent errors up to 30secs
@ -266,7 +257,6 @@ class Updater:
raise TelegramError('SSL Certificate invalid')
self.httpd.serve_forever(poll_interval=1)
self.logger.info('Updater thread stopped')
def stop(self):
"""
@ -276,25 +266,49 @@ class Updater:
self.job_queue.stop()
with self.__lock:
if self.running:
self.running = False
self.logger.info('Stopping Updater and Dispatcher...')
self.logger.debug('This might take a long time if you set a '
'high value as polling timeout.')
if self.httpd:
self.logger.info(
'Waiting for current webhook connection to be '
'closed... Send a Telegram message to the bot to exit '
'immediately.')
self.httpd.shutdown()
self.httpd = None
self.running = False
self.logger.debug("Requesting Dispatcher to stop...")
self.dispatcher.stop()
while dispatcher.running_async > 0:
sleep(1)
self._stop_httpd()
self._stop_dispatcher()
self._join_threads()
# async threads must be join()ed only after the dispatcher
# thread was joined, otherwise we can still have new async
# threads dispatched
self._join_async_threads()
self.logger.debug("Dispatcher stopped.")
def _stop_httpd(self):
if self.httpd:
self.logger.info(
'Waiting for current webhook connection to be '
'closed... Send a Telegram message to the bot to exit '
'immediately.')
self.httpd.shutdown()
self.httpd = None
def _stop_dispatcher(self):
self.logger.debug("Requesting Dispatcher to stop...")
self.dispatcher.stop()
def _join_async_threads(self):
with dispatcher.async_lock:
threads = list(dispatcher.async_threads)
total = len(threads)
for i, thr in enumerate(threads):
self.logger.info(
'Waiting for async thread {0}/{1} to end'.format(i, total))
thr.join()
self.logger.debug(
'async thread {0}/{1} has ended'.format(i, total))
def _join_threads(self):
for thr in self.__threads:
self.logger.info(
'Waiting for {0} thread to end'.format(thr.name))
thr.join()
self.logger.debug('{0} thread has ended'.format(thr.name))
self.__threads = []
def signal_handler(self, signum, frame):
self.is_idle = False

View file

@ -1,7 +1,7 @@
import logging
from telegram import Update, NullHandler
from future.utils import bytes_to_native_str as n
from future.utils import bytes_to_native_str
from threading import Lock
import json
try:
@ -14,6 +14,13 @@ H = NullHandler()
logging.getLogger(__name__).addHandler(H)
class _InvalidPost(Exception):
def __init__(self, http_code):
self.http_code = http_code
super(_InvalidPost, self).__init__()
class WebhookServer(BaseHTTPServer.HTTPServer, object):
def __init__(self, server_address, RequestHandlerClass, update_queue,
webhook_path):
@ -63,12 +70,15 @@ class WebhookHandler(BaseHTTPServer.BaseHTTPRequestHandler, object):
def do_POST(self):
self.logger.debug("Webhook triggered")
if self.path == self.server.webhook_path and \
'content-type' in self.headers and \
'content-length' in self.headers and \
self.headers['content-type'] == 'application/json':
json_string = \
n(self.rfile.read(int(self.headers['content-length'])))
try:
self._validate_post()
clen = self._get_content_len()
except _InvalidPost as e:
self.send_error(e.http_code)
self.end_headers()
else:
buf = self.rfile.read(clen)
json_string = bytes_to_native_str(buf)
self.send_response(200)
self.end_headers()
@ -80,6 +90,20 @@ class WebhookHandler(BaseHTTPServer.BaseHTTPRequestHandler, object):
update.update_id)
self.server.update_queue.put(update)
else:
self.send_error(403)
self.end_headers()
def _validate_post(self):
if not (self.path == self.server.webhook_path and
'content-type' in self.headers and
self.headers['content-type'] == 'application/json'):
raise _InvalidPost(403)
def _get_content_len(self):
clen = self.headers.get('content-length')
if clen is None:
raise _InvalidPost(411)
try:
clen = int(clen)
except ValueError:
raise _InvalidPost(403)
if clen < 0:
raise _InvalidPost(403)
return clen

View file

@ -30,6 +30,7 @@ import signal
from random import randrange
from time import sleep
from datetime import datetime
from future.builtins import bytes
if sys.version_info[0:2] == (2, 6):
import unittest2 as unittest
@ -37,9 +38,12 @@ else:
import unittest
try:
from urllib2 import urlopen, Request
# python2
from urllib2 import urlopen, Request, HTTPError
except ImportError:
# python3
from urllib.request import Request, urlopen
from urllib.error import HTTPError
sys.path.append('.')
@ -399,9 +403,9 @@ class UpdaterTest(BaseTest, unittest.TestCase):
d.addTelegramMessageHandler(
self.telegramHandlerTest)
# Select random port for travis
port = randrange(1024, 49152)
self.updater.start_webhook('127.0.0.1', port,
ip = '127.0.0.1'
port = randrange(1024, 49152) # Select random port for travis
self.updater.start_webhook(ip, port,
url_path='TOKEN',
cert='./tests/test_updater.py',
key='./tests/test_updater.py')
@ -417,34 +421,19 @@ class UpdaterTest(BaseTest, unittest.TestCase):
update = Update(1)
update.message = message
try:
payload = bytes(update.to_json(), encoding='utf-8')
except TypeError:
payload = bytes(update.to_json())
header = {
'content-type': 'application/json',
'content-length': str(len(payload))
}
r = Request('http://127.0.0.1:%d/TOKEN' % port,
data=payload,
headers=header)
urlopen(r)
self._send_webhook_msg(ip, port, update.to_json(), 'TOKEN')
sleep(1)
self.assertEqual(self.received_message, 'Webhook Test')
print("Test other webhook server functionalities...")
request = Request('http://localhost:%d/webookhandler.py' % port)
response = urlopen(request)
response = self._send_webhook_msg(ip, port, None, 'webookhandler.py')
self.assertEqual(b'', response.read())
self.assertEqual(200, response.code)
request.get_method = lambda: 'HEAD'
response = self._send_webhook_msg(ip, port, None, 'webookhandler.py',
get_method=lambda: 'HEAD')
response = urlopen(request)
self.assertEqual(b'', response.read())
self.assertEqual(200, response.code)
@ -460,9 +449,9 @@ class UpdaterTest(BaseTest, unittest.TestCase):
d.addTelegramMessageHandler(
self.telegramHandlerTest)
# Select random port for travis
port = randrange(1024, 49152)
self.updater.start_webhook('127.0.0.1', port)
ip = '127.0.0.1'
port = randrange(1024, 49152) # Select random port for travis
self.updater.start_webhook(ip, port)
sleep(0.5)
# Now, we send an update to the server via urlopen
@ -473,24 +462,78 @@ class UpdaterTest(BaseTest, unittest.TestCase):
update = Update(1)
update.message = message
try:
payload = bytes(update.to_json(), encoding='utf-8')
except TypeError:
payload = bytes(update.to_json())
header = {
'content-type': 'application/json',
'content-length': str(len(payload))
}
r = Request('http://127.0.0.1:%d/' % port,
data=payload,
headers=header)
urlopen(r)
self._send_webhook_msg(ip, port, update.to_json())
sleep(1)
self.assertEqual(self.received_message, 'Webhook Test 2')
def test_webhook_invalid_posts(self):
self._setup_updater('', messages=0)
ip = '127.0.0.1'
port = randrange(1024, 49152) # select random port for travis
thr = Thread(target=self.updater._start_webhook,
args=(ip, port, '', None, None))
thr.start()
sleep(0.5)
try:
with self.assertRaises(HTTPError) as ctx:
self._send_webhook_msg(ip, port,
'<root><bla>data</bla></root>',
content_type='application/xml')
self.assertEqual(ctx.exception.code, 403)
with self.assertRaises(HTTPError) as ctx:
self._send_webhook_msg(ip, port, 'dummy-payload',
content_len=-2)
self.assertEqual(ctx.exception.code, 403)
# TODO: prevent urllib or the underlying from adding content-length
# with self.assertRaises(HTTPError) as ctx:
# self._send_webhook_msg(ip, port, 'dummy-payload',
# content_len=None)
# self.assertEqual(ctx.exception.code, 411)
with self.assertRaises(HTTPError) as ctx:
self._send_webhook_msg(ip, port, 'dummy-payload',
content_len='not-a-number')
self.assertEqual(ctx.exception.code, 403)
finally:
self.updater._stop_httpd()
thr.join()
def _send_webhook_msg(self, ip, port, payload_str, url_path='',
content_len=-1, content_type='application/json',
get_method=None):
headers = {
'content-type': content_type,
}
if not payload_str:
content_len = None
payload = None
else:
payload = bytes(payload_str, encoding='utf-8')
if content_len == -1:
content_len = len(payload)
if content_len is not None:
headers['content-length'] = str(content_len)
url = 'http://{ip}:{port}/{path}'.format(ip=ip, port=port,
path=url_path)
req = Request(url, data=payload, headers=headers)
if get_method is not None:
req.get_method = get_method
return urlopen(req)
def signalsender(self):
sleep(0.5)
os.kill(os.getpid(), signal.SIGTERM)