Merge branch 'master' into jobqueue-rework

Conflicts:
	tests/test_jobqueue.py
This commit is contained in:
Jannes Höke 2016-06-20 05:32:15 +02:00
commit c4a8ee5175
7 changed files with 162 additions and 111 deletions

View file

@ -1 +1,3 @@
future future>=0.15.2
urllib3>=1.8.3
certifi

View file

@ -20,24 +20,47 @@
import logging import logging
from functools import wraps from functools import wraps
from threading import Thread, BoundedSemaphore, Lock, Event, current_thread from threading import Thread, Lock, Event, current_thread
from time import sleep from time import sleep
from queue import Queue, Empty
from queue import Empty from future.builtins import range
from telegram import (TelegramError, NullHandler) from telegram import (TelegramError, NullHandler)
from telegram.utils import request
from telegram.ext.handler import Handler from telegram.ext.handler import Handler
from telegram.utils.deprecate import deprecate from telegram.utils.deprecate import deprecate
logging.getLogger(__name__).addHandler(NullHandler()) logging.getLogger(__name__).addHandler(NullHandler())
semaphore = None ASYNC_QUEUE = Queue()
async_threads = set() ASYNC_THREADS = set()
""":type: set[Thread]""" """:type: set[Thread]"""
async_lock = Lock() ASYNC_LOCK = Lock() # guards ASYNC_THREADS
DEFAULT_GROUP = 0 DEFAULT_GROUP = 0
def _pooled():
"""
A wrapper to run a thread in a thread pool
"""
while 1:
try:
func, args, kwargs = ASYNC_QUEUE.get()
# If unpacking fails, the thread pool is being closed from Updater._join_async_threads
except TypeError:
logging.getLogger(__name__).debug("Closing run_async thread %s/%d" %
(current_thread().getName(), len(ASYNC_THREADS)))
break
try:
func(*args, **kwargs)
except:
logging.getLogger(__name__).exception("run_async function raised exception")
def run_async(func): def run_async(func):
""" """
Function decorator that will run the function in a new thread. Function decorator that will run the function in a new thread.
@ -53,30 +76,11 @@ def run_async(func):
# set a threading.Event to notify caller thread # set a threading.Event to notify caller thread
@wraps(func) @wraps(func)
def pooled(*pargs, **kwargs): def async_func(*args, **kwargs):
"""
A wrapper to run a thread in a thread pool
"""
try:
result = func(*pargs, **kwargs)
finally:
semaphore.release()
with async_lock:
async_threads.remove(current_thread())
return result
@wraps(func)
def async_func(*pargs, **kwargs):
""" """
A wrapper to run a function in a thread A wrapper to run a function in a thread
""" """
thread = Thread(target=pooled, args=pargs, kwargs=kwargs) ASYNC_QUEUE.put((func, args, kwargs))
semaphore.acquire()
with async_lock:
async_threads.add(thread)
thread.start()
return thread
return async_func return async_func
@ -112,11 +116,18 @@ class Dispatcher(object):
self.__stop_event = Event() self.__stop_event = Event()
self.__exception_event = exception_event or Event() self.__exception_event = exception_event or Event()
global semaphore with ASYNC_LOCK:
if not semaphore: if not ASYNC_THREADS:
semaphore = BoundedSemaphore(value=workers) if request.is_con_pool_initialized():
raise RuntimeError('Connection Pool already initialized')
request.CON_POOL_SIZE = workers + 3
for i in range(workers):
thread = Thread(target=_pooled, name=str(i))
ASYNC_THREADS.add(thread)
thread.start()
else: else:
self.logger.debug('Semaphore already initialized, skipping.') self.logger.debug('Thread pool already initialized, skipping.')
def start(self): def start(self):
""" """
@ -136,7 +147,7 @@ class Dispatcher(object):
self.running = True self.running = True
self.logger.debug('Dispatcher started') self.logger.debug('Dispatcher started')
while True: while 1:
try: try:
# Pop update from update queue. # Pop update from update queue.
update = self.update_queue.get(True, 1) update = self.update_queue.get(True, 1)
@ -150,7 +161,7 @@ class Dispatcher(object):
continue continue
self.logger.debug('Processing Update: %s' % update) self.logger.debug('Processing Update: %s' % update)
self.processUpdate(update) self.process_update(update)
self.running = False self.running = False
self.logger.debug('Dispatcher thread stopped') self.logger.debug('Dispatcher thread stopped')
@ -165,7 +176,7 @@ class Dispatcher(object):
sleep(0.1) sleep(0.1)
self.__stop_event.clear() self.__stop_event.clear()
def processUpdate(self, update): def process_update(self, update):
""" """
Processes a single update. Processes a single update.
@ -175,7 +186,7 @@ class Dispatcher(object):
# An error happened while polling # An error happened while polling
if isinstance(update, TelegramError): if isinstance(update, TelegramError):
self.dispatchError(None, update) self.dispatch_error(None, update)
else: else:
for group in self.groups: for group in self.groups:
@ -190,7 +201,7 @@ class Dispatcher(object):
'Update.') 'Update.')
try: try:
self.dispatchError(update, te) self.dispatch_error(update, te)
except Exception: except Exception:
self.logger.exception('An uncaught error was raised while ' self.logger.exception('An uncaught error was raised while '
'handling the error') 'handling the error')
@ -276,7 +287,7 @@ class Dispatcher(object):
if callback in self.error_handlers: if callback in self.error_handlers:
self.error_handlers.remove(callback) self.error_handlers.remove(callback)
def dispatchError(self, update, error): def dispatch_error(self, update, error):
""" """
Dispatches an error. Dispatches an error.

View file

@ -308,7 +308,7 @@ class Updater(object):
def _bootstrap(self, max_retries, clean, webhook_url, cert=None): def _bootstrap(self, max_retries, clean, webhook_url, cert=None):
retries = 0 retries = 0
while True: while 1:
try: try:
if clean: if clean:
@ -345,7 +345,7 @@ class Updater(object):
self.job_queue.stop() self.job_queue.stop()
with self.__lock: with self.__lock:
if self.running: if self.running or dispatcher.ASYNC_THREADS:
self.logger.debug('Stopping Updater and Dispatcher...') self.logger.debug('Stopping Updater and Dispatcher...')
self.running = False self.running = False
@ -353,9 +353,8 @@ class Updater(object):
self._stop_httpd() self._stop_httpd()
self._stop_dispatcher() self._stop_dispatcher()
self._join_threads() self._join_threads()
# async threads must be join()ed only after the dispatcher # async threads must be join()ed only after the dispatcher thread was joined,
# thread was joined, otherwise we can still have new async # otherwise we can still have new async threads dispatched
# threads dispatched
self._join_async_threads() self._join_async_threads()
def _stop_httpd(self): def _stop_httpd(self):
@ -371,13 +370,19 @@ class Updater(object):
self.dispatcher.stop() self.dispatcher.stop()
def _join_async_threads(self): def _join_async_threads(self):
with dispatcher.async_lock: with dispatcher.ASYNC_LOCK:
threads = list(dispatcher.async_threads) threads = list(dispatcher.ASYNC_THREADS)
total = len(threads) total = len(threads)
# Stop all threads in the thread pool by put()ting one non-tuple per thread
for i in range(total):
dispatcher.ASYNC_QUEUE.put(None)
for i, thr in enumerate(threads): for i, thr in enumerate(threads):
self.logger.debug('Waiting for async thread {0}/{1} to end'.format(i, total)) self.logger.debug('Waiting for async thread {0}/{1} to end'.format(i + 1, total))
thr.join() thr.join()
self.logger.debug('async thread {0}/{1} has ended'.format(i, total)) dispatcher.ASYNC_THREADS.remove(thr)
self.logger.debug('async thread {0}/{1} has ended'.format(i + 1, total))
def _join_threads(self): def _join_threads(self):
for thr in self.__threads: for thr in self.__threads:

View file

@ -18,29 +18,56 @@
# along with this program. If not, see [http://www.gnu.org/licenses/]. # along with this program. If not, see [http://www.gnu.org/licenses/].
"""This module contains methods to make POST and GET requests""" """This module contains methods to make POST and GET requests"""
import functools
import json import json
import socket import socket
from ssl import SSLError import logging
from future.moves.http.client import HTTPException import certifi
from future.moves.urllib.error import HTTPError, URLError import urllib3
from future.moves.urllib.request import urlopen, urlretrieve, Request from urllib3.connection import HTTPConnection
from telegram import (InputFile, TelegramError) from telegram import (InputFile, TelegramError)
from telegram.error import Unauthorized, NetworkError, TimedOut, BadRequest from telegram.error import Unauthorized, NetworkError, TimedOut, BadRequest
_CON_POOL = None
""":type: urllib3.PoolManager"""
CON_POOL_SIZE = 1
logging.getLogger('urllib3').setLevel(logging.WARNING)
def _get_con_pool():
global _CON_POOL
if _CON_POOL is not None:
return _CON_POOL
_CON_POOL = urllib3.PoolManager(maxsize=CON_POOL_SIZE,
cert_reqs='CERT_REQUIRED',
ca_certs=certifi.where(),
socket_options=HTTPConnection.default_socket_options + [
(socket.SOL_SOCKET, socket.SO_KEEPALIVE, 1),
])
return _CON_POOL
def is_con_pool_initialized():
return _CON_POOL is not None
def stop_con_pool():
global _CON_POOL
if _CON_POOL is not None:
_CON_POOL.clear()
_CON_POOL = None
def _parse(json_data): def _parse(json_data):
"""Try and parse the JSON returned from Telegram and return an empty """Try and parse the JSON returned from Telegram.
dictionary if there is any error.
Args:
url:
urllib.urlopen object
Returns: Returns:
A JSON parsed as Python dict with results. dict: A JSON parsed as Python dict with results - on error this dict will be empty.
""" """
decoded_s = json_data.decode('utf-8') decoded_s = json_data.decode('utf-8')
try: try:
@ -54,53 +81,49 @@ def _parse(json_data):
return data['result'] return data['result']
def _try_except_req(func): def _request_wrapper(*args, **kwargs):
"""Decorator for requests to handle known exceptions""" """Wraps urllib3 request for handling known exceptions.
@functools.wraps(func) Args:
def decorator(*args, **kwargs): args: unnamed arguments, passed to urllib3 request.
try: kwargs: keyword arguments, passed tp urllib3 request.
return func(*args, **kwargs)
except HTTPError as error: Returns:
# `HTTPError` inherits from `URLError` so `HTTPError` handling must str: A non-parsed JSON text.
# come first.
errcode = error.getcode() Raises:
TelegramError
"""
try: try:
message = _parse(error.read()) resp = _get_con_pool().request(*args, **kwargs)
except urllib3.exceptions.TimeoutError as error:
if errcode in (401, 403):
raise Unauthorized()
elif errcode == 400:
raise BadRequest(message)
elif errcode == 502:
raise NetworkError('Bad Gateway')
except ValueError:
message = 'Unknown HTTPError {0}'.format(error.getcode())
raise NetworkError('{0} ({1})'.format(message, errcode))
except URLError as error:
raise NetworkError('URLError: {0}'.format(error.reason))
except (SSLError, socket.timeout) as error:
err_s = str(error)
if 'operation timed out' in err_s:
raise TimedOut() raise TimedOut()
except urllib3.exceptions.HTTPError as error:
# HTTPError must come last as its the base urllib3 exception class
# TODO: do something smart here; for now just raise NetworkError
raise NetworkError('urllib3 HTTPError {0}'.format(error))
raise NetworkError(err_s) if 200 <= resp.status <= 299:
# 200-299 range are HTTP success statuses
return resp.data
except HTTPException as error: try:
raise NetworkError('HTTPException: {0!r}'.format(error)) message = _parse(resp.data)
except ValueError:
raise NetworkError('Unknown HTTPError {0}'.format(resp.status))
except socket.error as error: if resp.status in (401, 403):
raise NetworkError('socket.error: {0!r}'.format(error)) raise Unauthorized()
elif resp.status == 400:
return decorator raise BadRequest(repr(message))
elif resp.status == 502:
raise NetworkError('Bad Gateway')
else:
raise NetworkError('{0} ({1})'.format(message, resp.status))
@_try_except_req
def get(url): def get(url):
"""Request an URL. """Request an URL.
Args: Args:
@ -109,13 +132,13 @@ def get(url):
Returns: Returns:
A JSON object. A JSON object.
""" """
result = urlopen(url).read() result = _request_wrapper('GET', url)
return _parse(result) return _parse(result)
@_try_except_req
def post(url, data, timeout=None): def post(url, data, timeout=None):
"""Request an URL. """Request an URL.
Args: Args:
@ -142,16 +165,17 @@ def post(url, data, timeout=None):
if InputFile.is_inputfile(data): if InputFile.is_inputfile(data):
data = InputFile(data) data = InputFile(data)
request = Request(url, data=data.to_form(), headers=data.headers) result = _request_wrapper('POST', url, body=data.to_form(), headers=data.headers)
else: else:
data = json.dumps(data) data = json.dumps(data)
request = Request(url, data=data.encode(), headers={'Content-Type': 'application/json'}) result = _request_wrapper('POST',
url,
body=data.encode(),
headers={'Content-Type': 'application/json'})
result = urlopen(request, **urlopen_kwargs).read()
return _parse(result) return _parse(result)
@_try_except_req
def download(url, filename): def download(url, filename):
"""Download a file by its URL. """Download a file by its URL.
Args: Args:
@ -160,6 +184,8 @@ def download(url, filename):
filename: filename:
The filename within the path to download the file. The filename within the path to download the file.
"""
urlretrieve(url, filename) """
buf = _request_wrapper('GET', url)
with open(filename, 'wb') as fobj:
fobj.write(buf)

View file

@ -20,6 +20,7 @@
"""This module contains a object that represents Tests for Telegram Bot""" """This module contains a object that represents Tests for Telegram Bot"""
import io import io
import re
from datetime import datetime from datetime import datetime
import sys import sys
@ -211,10 +212,11 @@ class BotTest(BaseTest, unittest.TestCase):
@flaky(3, 1) @flaky(3, 1)
@timeout(10) @timeout(10)
def testLeaveChat(self): def testLeaveChat(self):
with self.assertRaisesRegexp(telegram.error.BadRequest, 'Chat not found'): regex = re.compile('chat not found', re.IGNORECASE)
with self.assertRaisesRegexp(telegram.error.BadRequest, regex):
chat = self._bot.leaveChat(-123456) chat = self._bot.leaveChat(-123456)
with self.assertRaisesRegexp(telegram.error.NetworkError, 'Chat not found'): with self.assertRaisesRegexp(telegram.error.NetworkError, regex):
chat = self._bot.leaveChat(-123456) chat = self._bot.leaveChat(-123456)
@flaky(3, 1) @flaky(3, 1)

View file

@ -31,6 +31,7 @@ else:
sys.path.append('.') sys.path.append('.')
from telegram.utils.request import stop_con_pool
from telegram.ext import JobQueue, Job, Updater from telegram.ext import JobQueue, Job, Updater
from tests.base import BaseTest from tests.base import BaseTest
@ -58,6 +59,7 @@ class JobQueueTest(BaseTest, unittest.TestCase):
def tearDown(self): def tearDown(self):
if self.jq is not None: if self.jq is not None:
self.jq.stop() self.jq.stop()
stop_con_pool()
def job1(self, bot, job): def job1(self, bot, job):
self.result += 1 self.result += 1

View file

@ -48,6 +48,7 @@ except ImportError:
sys.path.append('.') sys.path.append('.')
from telegram import Update, Message, TelegramError, User, Chat, Bot from telegram import Update, Message, TelegramError, User, Chat, Bot
from telegram.utils.request import stop_con_pool
from telegram.ext import * from telegram.ext import *
from telegram.ext.dispatcher import run_async from telegram.ext.dispatcher import run_async
from telegram.error import Unauthorized, InvalidToken from telegram.error import Unauthorized, InvalidToken
@ -83,12 +84,14 @@ class UpdaterTest(BaseTest, unittest.TestCase):
self.lock = Lock() self.lock = Lock()
def _setup_updater(self, *args, **kwargs): def _setup_updater(self, *args, **kwargs):
stop_con_pool()
bot = MockBot(*args, **kwargs) bot = MockBot(*args, **kwargs)
self.updater = Updater(workers=2, bot=bot) self.updater = Updater(workers=2, bot=bot)
def tearDown(self): def tearDown(self):
if self.updater is not None: if self.updater is not None:
self.updater.stop() self.updater.stop()
stop_con_pool()
def reset(self): def reset(self):
self.message_count = 0 self.message_count = 0
@ -648,8 +651,8 @@ class UpdaterTest(BaseTest, unittest.TestCase):
self.assertFalse(self.updater.running) self.assertFalse(self.updater.running)
def test_createBot(self): def test_createBot(self):
updater = Updater('123:abcd') self.updater = Updater('123:abcd')
self.assertIsNotNone(updater.bot) self.assertIsNotNone(self.updater.bot)
def test_mutualExclusiveTokenBot(self): def test_mutualExclusiveTokenBot(self):
bot = Bot('123:zyxw') bot = Bot('123:zyxw')