From e4a132c0e41b7494c0c68da39895f6f6305e74c7 Mon Sep 17 00:00:00 2001 From: Noam Meltzer Date: Tue, 6 Sep 2016 16:38:07 +0300 Subject: [PATCH] Reusable dispatcher (#402) * Create a Request class which maintains its own connection pool * When creating a Bot instance a new Request instance will be created if one wasn't supplied. * Updater is responsible for creating a Request instance if a Bot instance wasn't provided. * Dispatcher: add method to run async functions without decorator * Dispatcher can now run as a singleton (allowing run_async decorator to work) as it always did and as multiple instances (where run_async decorator will raise RuntimeError) --- .pre-commit-config.yaml | 6 +- telegram/bot.py | 43 ++-- telegram/ext/dispatcher.py | 169 +++++++++----- telegram/ext/updater.py | 42 ++-- telegram/file.py | 17 +- telegram/utils/request.py | 355 +++++++++++++----------------- tests/test_bot.py | 4 +- tests/test_conversationhandler.py | 15 +- tests/test_file.py | 6 +- tests/test_jobqueue.py | 21 +- tests/test_updater.py | 73 +++++- 11 files changed, 420 insertions(+), 331 deletions(-) diff --git a/.pre-commit-config.yaml b/.pre-commit-config.yaml index f0591f9ac..95efc1065 100644 --- a/.pre-commit-config.yaml +++ b/.pre-commit-config.yaml @@ -1,15 +1,15 @@ - repo: git://github.com/pre-commit/mirrors-yapf - sha: 34303f2856d4e4ba26dc302d9c28632e9b5a8626 + sha: v0.11.0 hooks: - id: yapf files: ^(telegram|tests)/.*\.py$ - repo: git://github.com/pre-commit/pre-commit-hooks - sha: 3fa02652357ff0dbb42b5bc78c673b7bc105fcf3 + sha: 18d7035de5388cc7775be57f529c154bf541aab9 hooks: - id: flake8 files: ^telegram/.*\.py$ - repo: git://github.com/pre-commit/mirrors-pylint - sha: 4de6c8dfadef1a271a814561ce05b8bc1c446d22 + sha: v1.5.5 hooks: - id: pylint files: ^telegram/.*\.py$ diff --git a/telegram/bot.py b/telegram/bot.py index 38091e885..a73ab1467 100644 --- a/telegram/bot.py +++ b/telegram/bot.py @@ -25,7 +25,7 @@ import logging from telegram import (User, Message, Update, Chat, ChatMember, UserProfilePhotos, File, ReplyMarkup, TelegramObject) from telegram.error import InvalidToken -from telegram.utils import request +from telegram.utils.request import Request logging.getLogger(__name__).addHandler(logging.NullHandler()) @@ -44,10 +44,11 @@ class Bot(TelegramObject): token (str): Bot's unique authentication. base_url (Optional[str]): Telegram Bot API service URL. base_file_url (Optional[str]): Telegram Bot API file URL. + request (Optional[Request]): Pre initialized `Request` class. """ - def __init__(self, token, base_url=None, base_file_url=None): + def __init__(self, token, base_url=None, base_file_url=None, request=None): self.token = self._validate_token(token) if not base_url: @@ -61,9 +62,13 @@ class Bot(TelegramObject): self.base_file_url = base_file_url + self.token self.bot = None - + self._request = request or Request() self.logger = logging.getLogger(__name__) + @property + def request(self): + return self._request + @staticmethod def _validate_token(token): """a very basic validation on token""" @@ -144,7 +149,7 @@ class Bot(TelegramObject): else: data['reply_markup'] = reply_markup - result = request.post(url, data, timeout=kwargs.get('timeout')) + result = self._request.post(url, data, timeout=kwargs.get('timeout')) if result is True: return result @@ -169,7 +174,7 @@ class Bot(TelegramObject): url = '{0}/getMe'.format(self.base_url) - result = request.get(url) + result = self._request.get(url) self.bot = User.de_json(result) @@ -813,7 +818,7 @@ class Bot(TelegramObject): if switch_pm_parameter: data['switch_pm_parameter'] = switch_pm_parameter - result = request.post(url, data, timeout=kwargs.get('timeout')) + result = self._request.post(url, data, timeout=kwargs.get('timeout')) return result @@ -853,7 +858,7 @@ class Bot(TelegramObject): if limit: data['limit'] = limit - result = request.post(url, data, timeout=kwargs.get('timeout')) + result = self._request.post(url, data, timeout=kwargs.get('timeout')) return UserProfilePhotos.de_json(result) @@ -884,12 +889,12 @@ class Bot(TelegramObject): data = {'file_id': file_id} - result = request.post(url, data, timeout=kwargs.get('timeout')) + result = self._request.post(url, data, timeout=kwargs.get('timeout')) if result.get('file_path'): result['file_path'] = '%s/%s' % (self.base_file_url, result['file_path']) - return File.de_json(result) + return File.de_json(result, self._request) @log def kickChatMember(self, chat_id, user_id, **kwargs): @@ -921,7 +926,7 @@ class Bot(TelegramObject): data = {'chat_id': chat_id, 'user_id': user_id} - result = request.post(url, data, timeout=kwargs.get('timeout')) + result = self._request.post(url, data, timeout=kwargs.get('timeout')) return result @@ -955,7 +960,7 @@ class Bot(TelegramObject): data = {'chat_id': chat_id, 'user_id': user_id} - result = request.post(url, data, timeout=kwargs.get('timeout')) + result = self._request.post(url, data, timeout=kwargs.get('timeout')) return result @@ -999,7 +1004,7 @@ class Bot(TelegramObject): if show_alert: data['show_alert'] = show_alert - result = request.post(url, data, timeout=kwargs.get('timeout')) + result = self._request.post(url, data, timeout=kwargs.get('timeout')) return result @@ -1213,7 +1218,7 @@ class Bot(TelegramObject): urlopen_timeout = timeout + network_delay - result = request.post(url, data, timeout=urlopen_timeout) + result = self._request.post(url, data, timeout=urlopen_timeout) if result: self.logger.debug('Getting updates: %s', [u['update_id'] for u in result]) @@ -1256,7 +1261,7 @@ class Bot(TelegramObject): if certificate: data['certificate'] = certificate - result = request.post(url, data, timeout=kwargs.get('timeout')) + result = self._request.post(url, data, timeout=kwargs.get('timeout')) return result @@ -1286,7 +1291,7 @@ class Bot(TelegramObject): data = {'chat_id': chat_id} - result = request.post(url, data, timeout=kwargs.get('timeout')) + result = self._request.post(url, data, timeout=kwargs.get('timeout')) return result @@ -1318,7 +1323,7 @@ class Bot(TelegramObject): data = {'chat_id': chat_id} - result = request.post(url, data, timeout=kwargs.get('timeout')) + result = self._request.post(url, data, timeout=kwargs.get('timeout')) return Chat.de_json(result) @@ -1353,7 +1358,7 @@ class Bot(TelegramObject): data = {'chat_id': chat_id} - result = request.post(url, data, timeout=kwargs.get('timeout')) + result = self._request.post(url, data, timeout=kwargs.get('timeout')) return [ChatMember.de_json(x) for x in result] @@ -1383,7 +1388,7 @@ class Bot(TelegramObject): data = {'chat_id': chat_id} - result = request.post(url, data, timeout=kwargs.get('timeout')) + result = self._request.post(url, data, timeout=kwargs.get('timeout')) return result @@ -1416,7 +1421,7 @@ class Bot(TelegramObject): data = {'chat_id': chat_id, 'user_id': user_id} - result = request.post(url, data, timeout=kwargs.get('timeout')) + result = self._request.post(url, data, timeout=kwargs.get('timeout')) return ChatMember.de_json(result) diff --git a/telegram/ext/dispatcher.py b/telegram/ext/dispatcher.py index c6693e53a..c80c68aec 100644 --- a/telegram/ext/dispatcher.py +++ b/telegram/ext/dispatcher.py @@ -19,70 +19,43 @@ """This module contains the Dispatcher class.""" import logging +import weakref from functools import wraps -from threading import Thread, Lock, Event, current_thread +from threading import Thread, Lock, Event, current_thread, BoundedSemaphore from time import sleep +from uuid import uuid4 + from queue import Queue, Empty from future.builtins import range from telegram import TelegramError -from telegram.utils import request from telegram.ext.handler import Handler from telegram.utils.deprecate import deprecate from telegram.utils.promise import Promise logging.getLogger(__name__).addHandler(logging.NullHandler()) - -ASYNC_QUEUE = Queue() -ASYNC_THREADS = set() """:type: set[Thread]""" -ASYNC_LOCK = Lock() # guards ASYNC_THREADS DEFAULT_GROUP = 0 -def _pooled(): - """ - A wrapper to run a thread in a thread pool - """ - while 1: - promise = ASYNC_QUEUE.get() - - # If unpacking fails, the thread pool is being closed from Updater._join_async_threads - if not isinstance(promise, Promise): - logging.getLogger(__name__).debug("Closing run_async thread %s/%d" % - (current_thread().getName(), len(ASYNC_THREADS))) - break - - try: - promise.run() - - except: - logging.getLogger(__name__).exception("run_async function raised exception") - - def run_async(func): - """ - Function decorator that will run the function in a new thread. + """Function decorator that will run the function in a new thread. + + Using this decorator is only possible when only a single Dispatcher exist in the system. Args: func (function): The function to run in the thread. + async_queue (Queue): The queue of the functions to be executed asynchronously. Returns: function: - """ - # TODO: handle exception in async threads - # set a threading.Event to notify caller thread + """ @wraps(func) def async_func(*args, **kwargs): - """ - A wrapper to run a function in a thread - """ - promise = Promise(func, args, kwargs) - ASYNC_QUEUE.put(promise) - return promise + return Dispatcher.get_instance().run_async(func, *args, **kwargs) return async_func @@ -100,7 +73,12 @@ class Dispatcher(object): callbacks workers (Optional[int]): Number of maximum concurrent worker threads for the ``@run_async`` decorator + """ + __singleton_lock = Lock() + __singleton_semaphore = BoundedSemaphore() + __singleton = None + logger = logging.getLogger(__name__) def __init__(self, bot, update_queue, workers=4, exception_event=None, job_queue=None): self.bot = bot @@ -113,28 +91,92 @@ class Dispatcher(object): """:type: list[int]""" self.error_handlers = [] - self.logger = logging.getLogger(__name__) self.running = False self.__stop_event = Event() self.__exception_event = exception_event or Event() + self.__async_queue = Queue() + self.__async_threads = set() - with ASYNC_LOCK: - if not ASYNC_THREADS: - if request.is_con_pool_initialized(): - raise RuntimeError('Connection Pool already initialized') - - # we need a connection pool the size of: - # * for each of the workers - # * 1 for Dispatcher - # * 1 for polling Updater (even if updater is webhook, we can spare a connection) - # * 1 for JobQueue - request.CON_POOL_SIZE = workers + 3 - for i in range(workers): - thread = Thread(target=_pooled, name=str(i)) - ASYNC_THREADS.add(thread) - thread.start() + # For backward compatibility, we allow a "singleton" mode for the dispatcher. When there's + # only one instance of Dispatcher, it will be possible to use the `run_async` decorator. + with self.__singleton_lock: + if self.__singleton_semaphore.acquire(blocking=0): + self._set_singleton(self) else: - self.logger.debug('Thread pool already initialized, skipping.') + self._set_singleton(None) + + self._init_async_threads(uuid4(), workers) + + @classmethod + def _reset_singleton(cls): + # NOTE: This method was added mainly for test_updater benefit and specifically pypy. Never + # call it in production code. + cls.__singleton_semaphore.release() + + def _init_async_threads(self, base_name, workers): + base_name = '{}_'.format(base_name) if base_name else '' + + for i in range(workers): + thread = Thread(target=self._pooled, name='{}{}'.format(base_name, i)) + self.__async_threads.add(thread) + thread.start() + + @classmethod + def _set_singleton(cls, val): + cls.logger.debug('Setting singleton dispatcher as %s', val) + cls.__singleton = weakref.ref(val) if val else None + + @classmethod + def get_instance(cls): + """Get the singleton instance of this class. + + Returns: + Dispatcher + + """ + if cls.__singleton is not None: + return cls.__singleton() + else: + raise RuntimeError('{} not initialized or multiple instances exist'.format( + cls.__name__)) + + def _pooled(self): + """ + A wrapper to run a thread in a thread pool + """ + thr_name = current_thread().getName() + while 1: + promise = self.__async_queue.get() + + # If unpacking fails, the thread pool is being closed from Updater._join_async_threads + if not isinstance(promise, Promise): + self.logger.debug("Closing run_async thread %s/%d", thr_name, + len(self.__async_threads)) + break + + try: + promise.run() + + except: + self.logger.exception("run_async function raised exception") + + def run_async(self, func, *args, **kwargs): + """Queue a function (with given args/kwargs) to be run asynchronously. + + Args: + func (function): The function to run in the thread. + args (Optional[tuple]): Arguments to `func`. + kwargs (Optional[dict]): Keyword arguments to `func`. + + Returns: + Promise + + """ + # TODO: handle exception in async threads + # set a threading.Event to notify caller thread + promise = Promise(func, args, kwargs) + self.__async_queue.put(promise) + return promise def start(self): """ @@ -183,6 +225,25 @@ class Dispatcher(object): sleep(0.1) self.__stop_event.clear() + # async threads must be join()ed only after the dispatcher thread was joined, + # otherwise we can still have new async threads dispatched + threads = list(self.__async_threads) + total = len(threads) + + # Stop all threads in the thread pool by put()ting one non-tuple per thread + for i in range(total): + self.__async_queue.put(None) + + for i, thr in enumerate(threads): + self.logger.debug('Waiting for async thread {0}/{1} to end'.format(i + 1, total)) + thr.join() + self.__async_threads.remove(thr) + self.logger.debug('async thread {0}/{1} has ended'.format(i + 1, total)) + + @property + def has_running_threads(self): + return self.running or bool(self.__async_threads) + def process_update(self, update): """ Processes a single update. diff --git a/telegram/ext/updater.py b/telegram/ext/updater.py index a9a4b3f67..969e23a13 100644 --- a/telegram/ext/updater.py +++ b/telegram/ext/updater.py @@ -29,8 +29,9 @@ from signal import signal, SIGINT, SIGTERM, SIGABRT from queue import Queue from telegram import Bot, TelegramError -from telegram.ext import dispatcher, Dispatcher, JobQueue +from telegram.ext import Dispatcher, JobQueue from telegram.error import Unauthorized, InvalidToken +from telegram.utils.request import Request from telegram.utils.webhookhandler import (WebhookServer, WebhookHandler) logging.getLogger(__name__).addHandler(logging.NullHandler()) @@ -57,13 +58,17 @@ class Updater(object): base_url (Optional[str]): workers (Optional[int]): Amount of threads in the thread pool for functions decorated with @run_async - bot (Optional[Bot]): + bot (Optional[Bot]): A pre-initialized bot instance. If a pre-initizlied bot is used, it is + the user's responsibility to create it using a `Request` instance with a large enough + connection pool. job_queue_tick_interval(Optional[float]): The interval the queue should be checked for new tasks. Defaults to 1.0 Raises: ValueError: If both `token` and `bot` are passed or none of them. + """ + _request = None def __init__(self, token=None, base_url=None, workers=4, bot=None): if (token is None) and (bot is None): @@ -74,7 +79,14 @@ class Updater(object): if bot is not None: self.bot = bot else: - self.bot = Bot(token, base_url) + # we need a connection pool the size of: + # * for each of the workers + # * 1 for Dispatcher + # * 1 for polling Updater (even if webhook is used, we can spare a connection) + # * 1 for JobQueue + # * 1 for main thread + self._request = Request(con_pool_size=workers + 4) + self.bot = Bot(token, base_url, request=self._request) self.update_queue = Queue() self.job_queue = JobQueue(self.bot) self.__exception_event = Event() @@ -344,7 +356,7 @@ class Updater(object): self.job_queue.stop() with self.__lock: - if self.running or dispatcher.ASYNC_THREADS: + if self.running or self.dispatcher.has_running_threads: self.logger.debug('Stopping Updater and Dispatcher...') self.running = False @@ -352,9 +364,10 @@ class Updater(object): self._stop_httpd() self._stop_dispatcher() self._join_threads() - # async threads must be join()ed only after the dispatcher thread was joined, - # otherwise we can still have new async threads dispatched - self._join_async_threads() + + # Stop the Request instance only if it was created by the Updater + if self._request: + self._request.stop() def _stop_httpd(self): if self.httpd: @@ -368,21 +381,6 @@ class Updater(object): self.logger.debug('Requesting Dispatcher to stop...') self.dispatcher.stop() - def _join_async_threads(self): - with dispatcher.ASYNC_LOCK: - threads = list(dispatcher.ASYNC_THREADS) - total = len(threads) - - # Stop all threads in the thread pool by put()ting one non-tuple per thread - for i in range(total): - dispatcher.ASYNC_QUEUE.put(None) - - for i, thr in enumerate(threads): - self.logger.debug('Waiting for async thread {0}/{1} to end'.format(i + 1, total)) - thr.join() - dispatcher.ASYNC_THREADS.remove(thr) - self.logger.debug('async thread {0}/{1} has ended'.format(i + 1, total)) - def _join_threads(self): for thr in self.__threads: self.logger.debug('Waiting for {0} thread to end'.format(thr.name)) diff --git a/telegram/file.py b/telegram/file.py index 74d02a7c4..5c3ca8062 100644 --- a/telegram/file.py +++ b/telegram/file.py @@ -21,7 +21,6 @@ from os.path import basename from telegram import TelegramObject -from telegram.utils.request import download as _download class File(TelegramObject): @@ -34,38 +33,44 @@ class File(TelegramObject): Args: file_id (str): + request (telegram.utils.request.Request): **kwargs: Arbitrary keyword arguments. Keyword Args: file_size (Optional[int]): file_path (Optional[str]): + """ - def __init__(self, file_id, **kwargs): + def __init__(self, file_id, request, **kwargs): # Required self.file_id = str(file_id) + self._request = request # Optionals self.file_size = int(kwargs.get('file_size', 0)) self.file_path = str(kwargs.get('file_path', '')) @staticmethod - def de_json(data): + def de_json(data, request): """ Args: - data (str): + data (dict): + request (telegram.utils.request.Request): Returns: telegram.File: + """ if not data: return None - return File(**data) + return File(request=request, **data) def download(self, custom_path=None): """ Args: custom_path (str): + """ url = self.file_path @@ -74,4 +79,4 @@ class File(TelegramObject): else: filename = basename(url) - _download(url, filename) + self._request.download(url, filename) diff --git a/telegram/utils/request.py b/telegram/utils/request.py index 96e8d8c95..7bfcbb9e5 100644 --- a/telegram/utils/request.py +++ b/telegram/utils/request.py @@ -33,232 +33,185 @@ from urllib3.connection import HTTPConnection from telegram import (InputFile, TelegramError) from telegram.error import Unauthorized, NetworkError, TimedOut, BadRequest, ChatMigrated -_CON_POOL = None -""":type: urllib3.PoolManager""" -_CON_POOL_PROXY = None -_CON_POOL_PROXY_KWARGS = {} -CON_POOL_SIZE = 1 - logging.getLogger('urllib3').setLevel(logging.WARNING) -def _get_con_pool(): - if _CON_POOL is not None: - return _CON_POOL - - _init_con_pool() - return _CON_POOL - - -def _init_con_pool(): - global _CON_POOL - kwargs = dict( - maxsize=CON_POOL_SIZE, - cert_reqs='CERT_REQUIRED', - ca_certs=certifi.where(), - socket_options=HTTPConnection.default_socket_options + [ - (socket.SOL_SOCKET, socket.SO_KEEPALIVE, 1), - ]) - proxy_url = _get_con_pool_proxy() - if not proxy_url: - mgr = urllib3.PoolManager(**kwargs) - else: - if _CON_POOL_PROXY_KWARGS: - kwargs.update(_CON_POOL_PROXY_KWARGS) - mgr = urllib3.proxy_from_url(proxy_url, **kwargs) - if mgr.proxy.auth: - # TODO: what about other auth types? - auth_hdrs = urllib3.make_headers(proxy_basic_auth=mgr.proxy.auth) - mgr.proxy_headers.update(auth_hdrs) - - _CON_POOL = mgr - - -def is_con_pool_initialized(): - return _CON_POOL is not None - - -def stop_con_pool(): - global _CON_POOL - if _CON_POOL is not None: - _CON_POOL.clear() - _CON_POOL = None - - -def set_con_pool_proxy(url, **urllib3_kwargs): - """Setup connection pool behind a proxy +class Request(object): + """ + Helper class for python-telegram-bot which provides methods to perform POST & GET towards + telegram servers. Args: - url (str): The URL to the proxy server. For example: `http://127.0.0.1:3128` - urllib3_kwargs (dict): Arbitrary arguments passed as-is to `urllib3.ProxyManager` - - """ - global _CON_POOL_PROXY - global _CON_POOL_PROXY_KWARGS - - if is_con_pool_initialized(): - raise TelegramError('conpool already initialized') - - _CON_POOL_PROXY = url - _CON_POOL_PROXY_KWARGS = urllib3_kwargs - - -def _get_con_pool_proxy(): - """Return the user configured proxy according to the following order: - - * proxy configured using `set_con_pool_proxy()`. - * proxy set in `HTTPS_PROXY` env. var. - * proxy set in `https_proxy` env. var. - * None (if no proxy is configured) - - Returns: - str | None - - """ - if _CON_POOL_PROXY: - return _CON_POOL_PROXY - from_env = os.environ.get('HTTPS_PROXY') - if from_env: - return from_env - from_env = os.environ.get('https_proxy') - if from_env: - return from_env - return None - - -def _parse(json_data): - """Try and parse the JSON returned from Telegram. - - Returns: - dict: A JSON parsed as Python dict with results - on error this dict will be empty. - - """ - decoded_s = json_data.decode('utf-8') - try: - data = json.loads(decoded_s) - except ValueError: - raise TelegramError('Invalid server response') - - if not data.get('ok'): - description = data.get('description') - parameters = data.get('parameters') - if parameters: - migrate_to_chat_id = parameters.get('migrate_to_chat_id') - if migrate_to_chat_id: - raise ChatMigrated(migrate_to_chat_id) - if description: - return description - - return data['result'] - - -def _request_wrapper(*args, **kwargs): - """Wraps urllib3 request for handling known exceptions. - - Args: - args: unnamed arguments, passed to urllib3 request. - kwargs: keyword arguments, passed tp urllib3 request. - - Returns: - str: A non-parsed JSON text. - - Raises: - TelegramError + proxy_url (str): The URL to the proxy server. For example: `http://127.0.0.1:3128`. + urllib3_proxy_kwargs (dict): Arbitrary arguments passed as-is to `urllib3.ProxyManager`. + This value will be ignored if proxy_url is not set. """ - try: - resp = _get_con_pool().request(*args, **kwargs) - except urllib3.exceptions.TimeoutError as error: - raise TimedOut() - except urllib3.exceptions.HTTPError as error: - # HTTPError must come last as its the base urllib3 exception class - # TODO: do something smart here; for now just raise NetworkError - raise NetworkError('urllib3 HTTPError {0}'.format(error)) + def __init__(self, con_pool_size=1, proxy_url=None, urllib3_proxy_kwargs=None): + if urllib3_proxy_kwargs is None: + urllib3_proxy_kwargs = dict() - if 200 <= resp.status <= 299: - # 200-299 range are HTTP success statuses - return resp.data + kwargs = dict( + maxsize=con_pool_size, + cert_reqs='CERT_REQUIRED', + ca_certs=certifi.where(), + socket_options=HTTPConnection.default_socket_options + [ + (socket.SOL_SOCKET, socket.SO_KEEPALIVE, 1), + ]) - try: - message = _parse(resp.data) - except ValueError: - raise NetworkError('Unknown HTTPError {0}'.format(resp.status)) + # Set a proxy according to the following order: + # * proxy defined in proxy_url (+ urllib3_proxy_kwargs) + # * proxy set in `HTTPS_PROXY` env. var. + # * proxy set in `https_proxy` env. var. + # * None (if no proxy is configured) - if resp.status in (401, 403): - raise Unauthorized() - elif resp.status == 400: - raise BadRequest(repr(message)) - elif resp.status == 502: - raise NetworkError('Bad Gateway') - else: - raise NetworkError('{0} ({1})'.format(message, resp.status)) + if not proxy_url: + proxy_url = os.environ.get('HTTPS_PROXY') or os.environ.get('https_proxy') + if not proxy_url: + mgr = urllib3.PoolManager(**kwargs) + else: + kwargs.update(urllib3_proxy_kwargs) + mgr = urllib3.proxy_from_url(proxy_url, **kwargs) + if mgr.proxy.auth: + # TODO: what about other auth types? + auth_hdrs = urllib3.make_headers(proxy_basic_auth=mgr.proxy.auth) + mgr.proxy_headers.update(auth_hdrs) -def get(url): - """Request an URL. - Args: - url: - The web location we want to retrieve. + self._con_pool = mgr - Returns: - A JSON object. + def stop(self): + self._con_pool.clear() - """ - result = _request_wrapper('GET', url) + @staticmethod + def _parse(json_data): + """Try and parse the JSON returned from Telegram. - return _parse(result) + Returns: + dict: A JSON parsed as Python dict with results - on error this dict will be empty. + """ + decoded_s = json_data.decode('utf-8') + try: + data = json.loads(decoded_s) + except ValueError: + raise TelegramError('Invalid server response') -def post(url, data, timeout=None): - """Request an URL. - Args: - url: - The web location we want to retrieve. - data: - A dict of (str, unicode) key/value pairs. - timeout: - float. If this value is specified, use it as the definitive timeout (in - seconds) for urlopen() operations. [Optional] + if not data.get('ok'): + description = data.get('description') + parameters = data.get('parameters') + if parameters: + migrate_to_chat_id = parameters.get('migrate_to_chat_id') + if migrate_to_chat_id: + raise ChatMigrated(migrate_to_chat_id) + if description: + return description - Notes: - If neither `timeout` nor `data['timeout']` is specified. The underlying - defaults are used. + return data['result'] - Returns: - A JSON object. + def _request_wrapper(self, *args, **kwargs): + """Wraps urllib3 request for handling known exceptions. - """ - urlopen_kwargs = {} + Args: + args: unnamed arguments, passed to urllib3 request. + kwargs: keyword arguments, passed tp urllib3 request. - if timeout is not None: - urlopen_kwargs['timeout'] = timeout + Returns: + str: A non-parsed JSON text. - if InputFile.is_inputfile(data): - data = InputFile(data) - result = _request_wrapper('POST', url, body=data.to_form(), headers=data.headers) - else: - data = json.dumps(data) - result = _request_wrapper( - 'POST', - url, - body=data.encode(), - headers={'Content-Type': 'application/json'}, - **urlopen_kwargs) + Raises: + TelegramError - return _parse(result) + """ + try: + resp = self._con_pool.request(*args, **kwargs) + except urllib3.exceptions.TimeoutError: + raise TimedOut() + except urllib3.exceptions.HTTPError as error: + # HTTPError must come last as its the base urllib3 exception class + # TODO: do something smart here; for now just raise NetworkError + raise NetworkError('urllib3 HTTPError {0}'.format(error)) + if 200 <= resp.status <= 299: + # 200-299 range are HTTP success statuses + return resp.data -def download(url, filename): - """Download a file by its URL. - Args: - url: - The web location we want to retrieve. + try: + message = self._parse(resp.data) + except ValueError: + raise NetworkError('Unknown HTTPError {0}'.format(resp.status)) - filename: - The filename within the path to download the file. + if resp.status in (401, 403): + raise Unauthorized() + elif resp.status == 400: + raise BadRequest(repr(message)) + elif resp.status == 502: + raise NetworkError('Bad Gateway') + else: + raise NetworkError('{0} ({1})'.format(message, resp.status)) - """ - buf = _request_wrapper('GET', url) - with open(filename, 'wb') as fobj: - fobj.write(buf) + def get(self, url): + """Request an URL. + Args: + url: + The web location we want to retrieve. + + Returns: + A JSON object. + + """ + result = self._request_wrapper('GET', url) + return self._parse(result) + + def post(self, url, data, timeout=None): + """Request an URL. + Args: + url: + The web location we want to retrieve. + data: + A dict of (str, unicode) key/value pairs. + timeout: + float. If this value is specified, use it as the definitive timeout (in + seconds) for urlopen() operations. [Optional] + + Notes: + If neither `timeout` nor `data['timeout']` is specified. The underlying + defaults are used. + + Returns: + A JSON object. + + """ + urlopen_kwargs = {} + + if timeout is not None: + urlopen_kwargs['timeout'] = timeout + + if InputFile.is_inputfile(data): + data = InputFile(data) + result = self._request_wrapper('POST', url, body=data.to_form(), headers=data.headers) + else: + data = json.dumps(data) + result = self._request_wrapper( + 'POST', + url, + body=data.encode(), + headers={'Content-Type': 'application/json'}, + **urlopen_kwargs) + + return self._parse(result) + + def download(self, url, filename): + """Download a file by its URL. + Args: + url: + The web location we want to retrieve. + + filename: + The filename within the path to download the file. + + """ + buf = self._request_wrapper('GET', url) + with open(filename, 'wb') as fobj: + fobj.write(buf) diff --git a/tests/test_bot.py b/tests/test_bot.py index f80b7f9ba..1cf378814 100644 --- a/tests/test_bot.py +++ b/tests/test_bot.py @@ -198,7 +198,9 @@ class BotTest(BaseTest, unittest.TestCase): def testInvalidSrvResp(self): with self.assertRaisesRegexp(telegram.TelegramError, 'Invalid server response'): # bypass the valid token check - bot = telegram.Bot.__new__(telegram.Bot) + newbot_cls = type( + 'NoTokenValidateBot', (telegram.Bot,), dict(_validate_token=lambda x, y: None)) + bot = newbot_cls('0xdeadbeef') bot.base_url = 'https://api.telegram.org/bot{0}'.format('12') bot.getMe() diff --git a/tests/test_conversationhandler.py b/tests/test_conversationhandler.py index c22d5bf1b..6f6cd4ba3 100644 --- a/tests/test_conversationhandler.py +++ b/tests/test_conversationhandler.py @@ -36,7 +36,6 @@ except ImportError: sys.path.append('.') from telegram import Update, Message, TelegramError, User, Chat, Bot -from telegram.utils.request import stop_con_pool from telegram.ext import * from tests.base import BaseTest from tests.test_updater import MockBot @@ -61,10 +60,10 @@ class ConversationHandlerTest(BaseTest, unittest.TestCase): # At first we're thirsty. Then we brew coffee, we drink it # and then we can start coding! END, THIRSTY, BREWING, DRINKING, CODING = range(-1, 4) + _updater = None # Test related def setUp(self): - self.updater = None self.current_state = dict() self.entry_points = [CommandHandler('start', self.start)] self.states = {self.THIRSTY: [CommandHandler('brew', self.brew), @@ -78,14 +77,22 @@ class ConversationHandlerTest(BaseTest, unittest.TestCase): self.fallbacks = [CommandHandler('eat', self.start)] def _setup_updater(self, *args, **kwargs): - stop_con_pool() bot = MockBot(*args, **kwargs) self.updater = Updater(workers=2, bot=bot) def tearDown(self): if self.updater is not None: self.updater.stop() - stop_con_pool() + + @property + def updater(self): + return self._updater + + @updater.setter + def updater(self, val): + if self._updater: + self._updater.stop() + self._updater = val def reset(self): self.current_state = dict() diff --git a/tests/test_file.py b/tests/test_file.py index cf1795dfd..a58718ae7 100644 --- a/tests/test_file.py +++ b/tests/test_file.py @@ -101,19 +101,19 @@ class FileTest(BaseTest, unittest.TestCase): self.assertTrue(os.path.isfile('telegram.ogg')) def test_file_de_json(self): - newFile = telegram.File.de_json(self.json_dict) + newFile = telegram.File.de_json(self.json_dict, None) self.assertEqual(newFile.file_id, self.json_dict['file_id']) self.assertEqual(newFile.file_path, self.json_dict['file_path']) self.assertEqual(newFile.file_size, self.json_dict['file_size']) def test_file_to_json(self): - newFile = telegram.File.de_json(self.json_dict) + newFile = telegram.File.de_json(self.json_dict, None) self.assertTrue(self.is_json(newFile.to_json())) def test_file_to_dict(self): - newFile = telegram.File.de_json(self.json_dict) + newFile = telegram.File.de_json(self.json_dict, None) self.assertTrue(self.is_dict(newFile.to_dict())) self.assertEqual(newFile['file_id'], self.json_dict['file_id']) diff --git a/tests/test_jobqueue.py b/tests/test_jobqueue.py index be481b937..acac73a0c 100644 --- a/tests/test_jobqueue.py +++ b/tests/test_jobqueue.py @@ -25,9 +25,10 @@ import sys import unittest from time import sleep +from tests.test_updater import MockBot + sys.path.append('.') -from telegram.utils.request import stop_con_pool from telegram.ext import JobQueue, Job, Updater from tests.base import BaseTest @@ -49,13 +50,12 @@ class JobQueueTest(BaseTest, unittest.TestCase): """ def setUp(self): - self.jq = JobQueue("Bot") + self.jq = JobQueue(MockBot('jobqueue_test')) self.result = 0 def tearDown(self): if self.jq is not None: self.jq.stop() - stop_con_pool() def job1(self, bot, job): self.result += 1 @@ -158,12 +158,15 @@ class JobQueueTest(BaseTest, unittest.TestCase): def test_inUpdater(self): u = Updater(bot="MockBot") - u.job_queue.put(Job(self.job1, 0.5)) - sleep(0.75) - self.assertEqual(1, self.result) - u.stop() - sleep(2) - self.assertEqual(1, self.result) + try: + u.job_queue.put(Job(self.job1, 0.5)) + sleep(0.75) + self.assertEqual(1, self.result) + u.stop() + sleep(2) + self.assertEqual(1, self.result) + finally: + u.stop() if __name__ == '__main__': diff --git a/tests/test_updater.py b/tests/test_updater.py index 763be399e..5b0e08995 100644 --- a/tests/test_updater.py +++ b/tests/test_updater.py @@ -29,10 +29,13 @@ import re import unittest from datetime import datetime from time import sleep +from queue import Queue from random import randrange from future.builtins import bytes +from telegram.utils.request import Request as Requester + try: # python2 from urllib2 import urlopen, Request, HTTPError @@ -44,12 +47,11 @@ except ImportError: sys.path.append('.') from telegram import Update, Message, TelegramError, User, Chat, Bot, InlineQuery, CallbackQuery -from telegram.utils.request import stop_con_pool from telegram.ext import * from telegram.ext.dispatcher import run_async from telegram.error import Unauthorized, InvalidToken from tests.base import BaseTest -from threading import Lock, Thread +from threading import Lock, Thread, current_thread, Semaphore # Enable logging root = logging.getLogger() @@ -68,10 +70,8 @@ class UpdaterTest(BaseTest, unittest.TestCase): WebhookHandler """ - updater = None + _updater = None received_message = None - message_count = None - lock = None def setUp(self): self.updater = None @@ -79,15 +79,25 @@ class UpdaterTest(BaseTest, unittest.TestCase): self.message_count = 0 self.lock = Lock() + @property + def updater(self): + return self._updater + + @updater.setter + def updater(self, val): + if self._updater: + self._updater.stop() + self._updater.dispatcher._reset_singleton() + del self._updater.dispatcher + + self._updater = val + def _setup_updater(self, *args, **kwargs): - stop_con_pool() bot = MockBot(*args, **kwargs) self.updater = Updater(workers=2, bot=bot) def tearDown(self): - if self.updater is not None: - self.updater.stop() - stop_con_pool() + self.updater = None def reset(self): self.message_count = 0 @@ -411,6 +421,51 @@ class UpdaterTest(BaseTest, unittest.TestCase): self.assertEqual(self.received_message, 'Test5') self.assertEqual(self.message_count, 2) + def test_multiple_dispatchers(self): + + def get_dispatcher_name(q): + q.put(current_thread().name) + sleep(1.2) + + d1 = Dispatcher(MockBot('disp1'), Queue(), workers=1) + d2 = Dispatcher(MockBot('disp2'), Queue(), workers=1) + q1 = Queue() + q2 = Queue() + + try: + d1.run_async(get_dispatcher_name, q1) + d2.run_async(get_dispatcher_name, q2) + + name1 = q1.get() + name2 = q2.get() + + self.assertNotEqual(name1, name2) + finally: + d1.stop() + d2.stop() + # following three lines are for pypy unitests + d1._reset_singleton() + del d1 + del d2 + + def test_multiple_dispatcers_no_decorator(self): + + @run_async + def must_raise_runtime_error(): + pass + + d1 = Dispatcher(MockBot('disp1'), Queue(), workers=1) + d2 = Dispatcher(MockBot('disp2'), Queue(), workers=1) + + self.assertRaises(RuntimeError, must_raise_runtime_error) + + d1.stop() + d2.stop() + # following three lines are for pypy unitests + d1._reset_singleton() + del d1 + del d2 + def test_additionalArgs(self): self._setup_updater('', messages=0) handler = StringCommandHandler(