diff --git a/requirements.txt b/requirements.txt index 2c6edea8d..d2fc6a44a 100644 --- a/requirements.txt +++ b/requirements.txt @@ -1 +1,3 @@ -future +future>=0.15.2 +urllib3>=1.8.3 +certifi diff --git a/telegram/ext/dispatcher.py b/telegram/ext/dispatcher.py index fd10639ff..fc7e6d5d4 100644 --- a/telegram/ext/dispatcher.py +++ b/telegram/ext/dispatcher.py @@ -20,24 +20,47 @@ import logging from functools import wraps -from threading import Thread, BoundedSemaphore, Lock, Event, current_thread +from threading import Thread, Lock, Event, current_thread from time import sleep +from queue import Queue, Empty -from queue import Empty +from future.builtins import range from telegram import (TelegramError, NullHandler) +from telegram.utils import request from telegram.ext.handler import Handler from telegram.utils.deprecate import deprecate logging.getLogger(__name__).addHandler(NullHandler()) -semaphore = None -async_threads = set() +ASYNC_QUEUE = Queue() +ASYNC_THREADS = set() """:type: set[Thread]""" -async_lock = Lock() +ASYNC_LOCK = Lock() # guards ASYNC_THREADS DEFAULT_GROUP = 0 +def _pooled(): + """ + A wrapper to run a thread in a thread pool + """ + while 1: + try: + func, args, kwargs = ASYNC_QUEUE.get() + + # If unpacking fails, the thread pool is being closed from Updater._join_async_threads + except TypeError: + logging.getLogger(__name__).debug("Closing run_async thread %s/%d" % + (current_thread().getName(), len(ASYNC_THREADS))) + break + + try: + func(*args, **kwargs) + + except: + logging.getLogger(__name__).exception("run_async function raised exception") + + def run_async(func): """ Function decorator that will run the function in a new thread. @@ -53,30 +76,11 @@ def run_async(func): # set a threading.Event to notify caller thread @wraps(func) - def pooled(*pargs, **kwargs): - """ - A wrapper to run a thread in a thread pool - """ - try: - result = func(*pargs, **kwargs) - finally: - semaphore.release() - - with async_lock: - async_threads.remove(current_thread()) - return result - - @wraps(func) - def async_func(*pargs, **kwargs): + def async_func(*args, **kwargs): """ A wrapper to run a function in a thread """ - thread = Thread(target=pooled, args=pargs, kwargs=kwargs) - semaphore.acquire() - with async_lock: - async_threads.add(thread) - thread.start() - return thread + ASYNC_QUEUE.put((func, args, kwargs)) return async_func @@ -112,11 +116,18 @@ class Dispatcher(object): self.__stop_event = Event() self.__exception_event = exception_event or Event() - global semaphore - if not semaphore: - semaphore = BoundedSemaphore(value=workers) - else: - self.logger.debug('Semaphore already initialized, skipping.') + with ASYNC_LOCK: + if not ASYNC_THREADS: + if request.is_con_pool_initialized(): + raise RuntimeError('Connection Pool already initialized') + + request.CON_POOL_SIZE = workers + 3 + for i in range(workers): + thread = Thread(target=_pooled, name=str(i)) + ASYNC_THREADS.add(thread) + thread.start() + else: + self.logger.debug('Thread pool already initialized, skipping.') def start(self): """ @@ -136,7 +147,7 @@ class Dispatcher(object): self.running = True self.logger.debug('Dispatcher started') - while True: + while 1: try: # Pop update from update queue. update = self.update_queue.get(True, 1) @@ -150,7 +161,7 @@ class Dispatcher(object): continue self.logger.debug('Processing Update: %s' % update) - self.processUpdate(update) + self.process_update(update) self.running = False self.logger.debug('Dispatcher thread stopped') @@ -165,7 +176,7 @@ class Dispatcher(object): sleep(0.1) self.__stop_event.clear() - def processUpdate(self, update): + def process_update(self, update): """ Processes a single update. @@ -175,7 +186,7 @@ class Dispatcher(object): # An error happened while polling if isinstance(update, TelegramError): - self.dispatchError(None, update) + self.dispatch_error(None, update) else: for group in self.groups: @@ -190,7 +201,7 @@ class Dispatcher(object): 'Update.') try: - self.dispatchError(update, te) + self.dispatch_error(update, te) except Exception: self.logger.exception('An uncaught error was raised while ' 'handling the error') @@ -276,7 +287,7 @@ class Dispatcher(object): if callback in self.error_handlers: self.error_handlers.remove(callback) - def dispatchError(self, update, error): + def dispatch_error(self, update, error): """ Dispatches an error. diff --git a/telegram/ext/updater.py b/telegram/ext/updater.py index 06e76d453..83c79e11b 100644 --- a/telegram/ext/updater.py +++ b/telegram/ext/updater.py @@ -308,7 +308,7 @@ class Updater(object): def _bootstrap(self, max_retries, clean, webhook_url, cert=None): retries = 0 - while True: + while 1: try: if clean: @@ -345,7 +345,7 @@ class Updater(object): self.job_queue.stop() with self.__lock: - if self.running: + if self.running or dispatcher.ASYNC_THREADS: self.logger.debug('Stopping Updater and Dispatcher...') self.running = False @@ -353,9 +353,8 @@ class Updater(object): self._stop_httpd() self._stop_dispatcher() self._join_threads() - # async threads must be join()ed only after the dispatcher - # thread was joined, otherwise we can still have new async - # threads dispatched + # async threads must be join()ed only after the dispatcher thread was joined, + # otherwise we can still have new async threads dispatched self._join_async_threads() def _stop_httpd(self): @@ -371,13 +370,19 @@ class Updater(object): self.dispatcher.stop() def _join_async_threads(self): - with dispatcher.async_lock: - threads = list(dispatcher.async_threads) - total = len(threads) - for i, thr in enumerate(threads): - self.logger.debug('Waiting for async thread {0}/{1} to end'.format(i, total)) - thr.join() - self.logger.debug('async thread {0}/{1} has ended'.format(i, total)) + with dispatcher.ASYNC_LOCK: + threads = list(dispatcher.ASYNC_THREADS) + total = len(threads) + + # Stop all threads in the thread pool by put()ting one non-tuple per thread + for i in range(total): + dispatcher.ASYNC_QUEUE.put(None) + + for i, thr in enumerate(threads): + self.logger.debug('Waiting for async thread {0}/{1} to end'.format(i + 1, total)) + thr.join() + dispatcher.ASYNC_THREADS.remove(thr) + self.logger.debug('async thread {0}/{1} has ended'.format(i + 1, total)) def _join_threads(self): for thr in self.__threads: diff --git a/telegram/utils/request.py b/telegram/utils/request.py index 52fe434d7..f801c54a3 100644 --- a/telegram/utils/request.py +++ b/telegram/utils/request.py @@ -18,29 +18,56 @@ # along with this program. If not, see [http://www.gnu.org/licenses/]. """This module contains methods to make POST and GET requests""" -import functools import json import socket -from ssl import SSLError +import logging -from future.moves.http.client import HTTPException -from future.moves.urllib.error import HTTPError, URLError -from future.moves.urllib.request import urlopen, urlretrieve, Request +import certifi +import urllib3 +from urllib3.connection import HTTPConnection from telegram import (InputFile, TelegramError) from telegram.error import Unauthorized, NetworkError, TimedOut, BadRequest +_CON_POOL = None +""":type: urllib3.PoolManager""" +CON_POOL_SIZE = 1 + +logging.getLogger('urllib3').setLevel(logging.WARNING) + + +def _get_con_pool(): + global _CON_POOL + + if _CON_POOL is not None: + return _CON_POOL + + _CON_POOL = urllib3.PoolManager(maxsize=CON_POOL_SIZE, + cert_reqs='CERT_REQUIRED', + ca_certs=certifi.where(), + socket_options=HTTPConnection.default_socket_options + [ + (socket.SOL_SOCKET, socket.SO_KEEPALIVE, 1), + ]) + return _CON_POOL + + +def is_con_pool_initialized(): + return _CON_POOL is not None + + +def stop_con_pool(): + global _CON_POOL + if _CON_POOL is not None: + _CON_POOL.clear() + _CON_POOL = None + def _parse(json_data): - """Try and parse the JSON returned from Telegram and return an empty - dictionary if there is any error. - - Args: - url: - urllib.urlopen object + """Try and parse the JSON returned from Telegram. Returns: - A JSON parsed as Python dict with results. + dict: A JSON parsed as Python dict with results - on error this dict will be empty. + """ decoded_s = json_data.decode('utf-8') try: @@ -54,53 +81,49 @@ def _parse(json_data): return data['result'] -def _try_except_req(func): - """Decorator for requests to handle known exceptions""" +def _request_wrapper(*args, **kwargs): + """Wraps urllib3 request for handling known exceptions. - @functools.wraps(func) - def decorator(*args, **kwargs): - try: - return func(*args, **kwargs) + Args: + args: unnamed arguments, passed to urllib3 request. + kwargs: keyword arguments, passed tp urllib3 request. - except HTTPError as error: - # `HTTPError` inherits from `URLError` so `HTTPError` handling must - # come first. - errcode = error.getcode() + Returns: + str: A non-parsed JSON text. - try: - message = _parse(error.read()) + Raises: + TelegramError - if errcode in (401, 403): - raise Unauthorized() - elif errcode == 400: - raise BadRequest(message) - elif errcode == 502: - raise NetworkError('Bad Gateway') - except ValueError: - message = 'Unknown HTTPError {0}'.format(error.getcode()) + """ - raise NetworkError('{0} ({1})'.format(message, errcode)) + try: + resp = _get_con_pool().request(*args, **kwargs) + except urllib3.exceptions.TimeoutError as error: + raise TimedOut() + except urllib3.exceptions.HTTPError as error: + # HTTPError must come last as its the base urllib3 exception class + # TODO: do something smart here; for now just raise NetworkError + raise NetworkError('urllib3 HTTPError {0}'.format(error)) - except URLError as error: - raise NetworkError('URLError: {0}'.format(error.reason)) + if 200 <= resp.status <= 299: + # 200-299 range are HTTP success statuses + return resp.data - except (SSLError, socket.timeout) as error: - err_s = str(error) - if 'operation timed out' in err_s: - raise TimedOut() + try: + message = _parse(resp.data) + except ValueError: + raise NetworkError('Unknown HTTPError {0}'.format(resp.status)) - raise NetworkError(err_s) - - except HTTPException as error: - raise NetworkError('HTTPException: {0!r}'.format(error)) - - except socket.error as error: - raise NetworkError('socket.error: {0!r}'.format(error)) - - return decorator + if resp.status in (401, 403): + raise Unauthorized() + elif resp.status == 400: + raise BadRequest(repr(message)) + elif resp.status == 502: + raise NetworkError('Bad Gateway') + else: + raise NetworkError('{0} ({1})'.format(message, resp.status)) -@_try_except_req def get(url): """Request an URL. Args: @@ -109,13 +132,13 @@ def get(url): Returns: A JSON object. + """ - result = urlopen(url).read() + result = _request_wrapper('GET', url) return _parse(result) -@_try_except_req def post(url, data, timeout=None): """Request an URL. Args: @@ -142,16 +165,17 @@ def post(url, data, timeout=None): if InputFile.is_inputfile(data): data = InputFile(data) - request = Request(url, data=data.to_form(), headers=data.headers) + result = _request_wrapper('POST', url, body=data.to_form(), headers=data.headers) else: data = json.dumps(data) - request = Request(url, data=data.encode(), headers={'Content-Type': 'application/json'}) + result = _request_wrapper('POST', + url, + body=data.encode(), + headers={'Content-Type': 'application/json'}) - result = urlopen(request, **urlopen_kwargs).read() return _parse(result) -@_try_except_req def download(url, filename): """Download a file by its URL. Args: @@ -160,6 +184,8 @@ def download(url, filename): filename: The filename within the path to download the file. - """ - urlretrieve(url, filename) + """ + buf = _request_wrapper('GET', url) + with open(filename, 'wb') as fobj: + fobj.write(buf) diff --git a/tests/test_bot.py b/tests/test_bot.py index db26e330f..5b59b990f 100644 --- a/tests/test_bot.py +++ b/tests/test_bot.py @@ -20,6 +20,7 @@ """This module contains a object that represents Tests for Telegram Bot""" import io +import re from datetime import datetime import sys @@ -211,10 +212,11 @@ class BotTest(BaseTest, unittest.TestCase): @flaky(3, 1) @timeout(10) def testLeaveChat(self): - with self.assertRaisesRegexp(telegram.error.BadRequest, 'Chat not found'): + regex = re.compile('chat not found', re.IGNORECASE) + with self.assertRaisesRegexp(telegram.error.BadRequest, regex): chat = self._bot.leaveChat(-123456) - with self.assertRaisesRegexp(telegram.error.NetworkError, 'Chat not found'): + with self.assertRaisesRegexp(telegram.error.NetworkError, regex): chat = self._bot.leaveChat(-123456) @flaky(3, 1) diff --git a/tests/test_jobqueue.py b/tests/test_jobqueue.py index aef9d6723..254668670 100644 --- a/tests/test_jobqueue.py +++ b/tests/test_jobqueue.py @@ -31,6 +31,7 @@ else: sys.path.append('.') +from telegram.utils.request import stop_con_pool from telegram.ext import JobQueue, Job, Updater from tests.base import BaseTest @@ -58,6 +59,7 @@ class JobQueueTest(BaseTest, unittest.TestCase): def tearDown(self): if self.jq is not None: self.jq.stop() + stop_con_pool() def job1(self, bot, job): self.result += 1 diff --git a/tests/test_updater.py b/tests/test_updater.py index f40e27951..c5d606495 100644 --- a/tests/test_updater.py +++ b/tests/test_updater.py @@ -48,6 +48,7 @@ except ImportError: sys.path.append('.') from telegram import Update, Message, TelegramError, User, Chat, Bot +from telegram.utils.request import stop_con_pool from telegram.ext import * from telegram.ext.dispatcher import run_async from telegram.error import Unauthorized, InvalidToken @@ -83,12 +84,14 @@ class UpdaterTest(BaseTest, unittest.TestCase): self.lock = Lock() def _setup_updater(self, *args, **kwargs): + stop_con_pool() bot = MockBot(*args, **kwargs) self.updater = Updater(workers=2, bot=bot) def tearDown(self): if self.updater is not None: self.updater.stop() + stop_con_pool() def reset(self): self.message_count = 0 @@ -648,8 +651,8 @@ class UpdaterTest(BaseTest, unittest.TestCase): self.assertFalse(self.updater.running) def test_createBot(self): - updater = Updater('123:abcd') - self.assertIsNotNone(updater.bot) + self.updater = Updater('123:abcd') + self.assertIsNotNone(self.updater.bot) def test_mutualExclusiveTokenBot(self): bot = Bot('123:zyxw')