Reusable dispatcher (#402)

* Create a Request class which maintains its own connection pool
* When creating a Bot instance a new Request instance will be created if one wasn't supplied.
* Updater is responsible for creating a Request instance if a Bot instance wasn't provided.
* Dispatcher: add method to run async functions without decorator
* Dispatcher can now run as a singleton (allowing run_async decorator to work) as it always did and as multiple instances (where run_async decorator will raise RuntimeError)
This commit is contained in:
Noam Meltzer 2016-09-06 16:38:07 +03:00 committed by GitHub
parent ca81a75f29
commit e4a132c0e4
11 changed files with 420 additions and 331 deletions

View file

@ -1,15 +1,15 @@
- repo: git://github.com/pre-commit/mirrors-yapf - repo: git://github.com/pre-commit/mirrors-yapf
sha: 34303f2856d4e4ba26dc302d9c28632e9b5a8626 sha: v0.11.0
hooks: hooks:
- id: yapf - id: yapf
files: ^(telegram|tests)/.*\.py$ files: ^(telegram|tests)/.*\.py$
- repo: git://github.com/pre-commit/pre-commit-hooks - repo: git://github.com/pre-commit/pre-commit-hooks
sha: 3fa02652357ff0dbb42b5bc78c673b7bc105fcf3 sha: 18d7035de5388cc7775be57f529c154bf541aab9
hooks: hooks:
- id: flake8 - id: flake8
files: ^telegram/.*\.py$ files: ^telegram/.*\.py$
- repo: git://github.com/pre-commit/mirrors-pylint - repo: git://github.com/pre-commit/mirrors-pylint
sha: 4de6c8dfadef1a271a814561ce05b8bc1c446d22 sha: v1.5.5
hooks: hooks:
- id: pylint - id: pylint
files: ^telegram/.*\.py$ files: ^telegram/.*\.py$

View file

@ -25,7 +25,7 @@ import logging
from telegram import (User, Message, Update, Chat, ChatMember, UserProfilePhotos, File, from telegram import (User, Message, Update, Chat, ChatMember, UserProfilePhotos, File,
ReplyMarkup, TelegramObject) ReplyMarkup, TelegramObject)
from telegram.error import InvalidToken from telegram.error import InvalidToken
from telegram.utils import request from telegram.utils.request import Request
logging.getLogger(__name__).addHandler(logging.NullHandler()) logging.getLogger(__name__).addHandler(logging.NullHandler())
@ -44,10 +44,11 @@ class Bot(TelegramObject):
token (str): Bot's unique authentication. token (str): Bot's unique authentication.
base_url (Optional[str]): Telegram Bot API service URL. base_url (Optional[str]): Telegram Bot API service URL.
base_file_url (Optional[str]): Telegram Bot API file URL. base_file_url (Optional[str]): Telegram Bot API file URL.
request (Optional[Request]): Pre initialized `Request` class.
""" """
def __init__(self, token, base_url=None, base_file_url=None): def __init__(self, token, base_url=None, base_file_url=None, request=None):
self.token = self._validate_token(token) self.token = self._validate_token(token)
if not base_url: if not base_url:
@ -61,9 +62,13 @@ class Bot(TelegramObject):
self.base_file_url = base_file_url + self.token self.base_file_url = base_file_url + self.token
self.bot = None self.bot = None
self._request = request or Request()
self.logger = logging.getLogger(__name__) self.logger = logging.getLogger(__name__)
@property
def request(self):
return self._request
@staticmethod @staticmethod
def _validate_token(token): def _validate_token(token):
"""a very basic validation on token""" """a very basic validation on token"""
@ -144,7 +149,7 @@ class Bot(TelegramObject):
else: else:
data['reply_markup'] = reply_markup data['reply_markup'] = reply_markup
result = request.post(url, data, timeout=kwargs.get('timeout')) result = self._request.post(url, data, timeout=kwargs.get('timeout'))
if result is True: if result is True:
return result return result
@ -169,7 +174,7 @@ class Bot(TelegramObject):
url = '{0}/getMe'.format(self.base_url) url = '{0}/getMe'.format(self.base_url)
result = request.get(url) result = self._request.get(url)
self.bot = User.de_json(result) self.bot = User.de_json(result)
@ -813,7 +818,7 @@ class Bot(TelegramObject):
if switch_pm_parameter: if switch_pm_parameter:
data['switch_pm_parameter'] = switch_pm_parameter data['switch_pm_parameter'] = switch_pm_parameter
result = request.post(url, data, timeout=kwargs.get('timeout')) result = self._request.post(url, data, timeout=kwargs.get('timeout'))
return result return result
@ -853,7 +858,7 @@ class Bot(TelegramObject):
if limit: if limit:
data['limit'] = limit data['limit'] = limit
result = request.post(url, data, timeout=kwargs.get('timeout')) result = self._request.post(url, data, timeout=kwargs.get('timeout'))
return UserProfilePhotos.de_json(result) return UserProfilePhotos.de_json(result)
@ -884,12 +889,12 @@ class Bot(TelegramObject):
data = {'file_id': file_id} data = {'file_id': file_id}
result = request.post(url, data, timeout=kwargs.get('timeout')) result = self._request.post(url, data, timeout=kwargs.get('timeout'))
if result.get('file_path'): if result.get('file_path'):
result['file_path'] = '%s/%s' % (self.base_file_url, result['file_path']) result['file_path'] = '%s/%s' % (self.base_file_url, result['file_path'])
return File.de_json(result) return File.de_json(result, self._request)
@log @log
def kickChatMember(self, chat_id, user_id, **kwargs): def kickChatMember(self, chat_id, user_id, **kwargs):
@ -921,7 +926,7 @@ class Bot(TelegramObject):
data = {'chat_id': chat_id, 'user_id': user_id} data = {'chat_id': chat_id, 'user_id': user_id}
result = request.post(url, data, timeout=kwargs.get('timeout')) result = self._request.post(url, data, timeout=kwargs.get('timeout'))
return result return result
@ -955,7 +960,7 @@ class Bot(TelegramObject):
data = {'chat_id': chat_id, 'user_id': user_id} data = {'chat_id': chat_id, 'user_id': user_id}
result = request.post(url, data, timeout=kwargs.get('timeout')) result = self._request.post(url, data, timeout=kwargs.get('timeout'))
return result return result
@ -999,7 +1004,7 @@ class Bot(TelegramObject):
if show_alert: if show_alert:
data['show_alert'] = show_alert data['show_alert'] = show_alert
result = request.post(url, data, timeout=kwargs.get('timeout')) result = self._request.post(url, data, timeout=kwargs.get('timeout'))
return result return result
@ -1213,7 +1218,7 @@ class Bot(TelegramObject):
urlopen_timeout = timeout + network_delay urlopen_timeout = timeout + network_delay
result = request.post(url, data, timeout=urlopen_timeout) result = self._request.post(url, data, timeout=urlopen_timeout)
if result: if result:
self.logger.debug('Getting updates: %s', [u['update_id'] for u in result]) self.logger.debug('Getting updates: %s', [u['update_id'] for u in result])
@ -1256,7 +1261,7 @@ class Bot(TelegramObject):
if certificate: if certificate:
data['certificate'] = certificate data['certificate'] = certificate
result = request.post(url, data, timeout=kwargs.get('timeout')) result = self._request.post(url, data, timeout=kwargs.get('timeout'))
return result return result
@ -1286,7 +1291,7 @@ class Bot(TelegramObject):
data = {'chat_id': chat_id} data = {'chat_id': chat_id}
result = request.post(url, data, timeout=kwargs.get('timeout')) result = self._request.post(url, data, timeout=kwargs.get('timeout'))
return result return result
@ -1318,7 +1323,7 @@ class Bot(TelegramObject):
data = {'chat_id': chat_id} data = {'chat_id': chat_id}
result = request.post(url, data, timeout=kwargs.get('timeout')) result = self._request.post(url, data, timeout=kwargs.get('timeout'))
return Chat.de_json(result) return Chat.de_json(result)
@ -1353,7 +1358,7 @@ class Bot(TelegramObject):
data = {'chat_id': chat_id} data = {'chat_id': chat_id}
result = request.post(url, data, timeout=kwargs.get('timeout')) result = self._request.post(url, data, timeout=kwargs.get('timeout'))
return [ChatMember.de_json(x) for x in result] return [ChatMember.de_json(x) for x in result]
@ -1383,7 +1388,7 @@ class Bot(TelegramObject):
data = {'chat_id': chat_id} data = {'chat_id': chat_id}
result = request.post(url, data, timeout=kwargs.get('timeout')) result = self._request.post(url, data, timeout=kwargs.get('timeout'))
return result return result
@ -1416,7 +1421,7 @@ class Bot(TelegramObject):
data = {'chat_id': chat_id, 'user_id': user_id} data = {'chat_id': chat_id, 'user_id': user_id}
result = request.post(url, data, timeout=kwargs.get('timeout')) result = self._request.post(url, data, timeout=kwargs.get('timeout'))
return ChatMember.de_json(result) return ChatMember.de_json(result)

View file

@ -19,70 +19,43 @@
"""This module contains the Dispatcher class.""" """This module contains the Dispatcher class."""
import logging import logging
import weakref
from functools import wraps from functools import wraps
from threading import Thread, Lock, Event, current_thread from threading import Thread, Lock, Event, current_thread, BoundedSemaphore
from time import sleep from time import sleep
from uuid import uuid4
from queue import Queue, Empty from queue import Queue, Empty
from future.builtins import range from future.builtins import range
from telegram import TelegramError from telegram import TelegramError
from telegram.utils import request
from telegram.ext.handler import Handler from telegram.ext.handler import Handler
from telegram.utils.deprecate import deprecate from telegram.utils.deprecate import deprecate
from telegram.utils.promise import Promise from telegram.utils.promise import Promise
logging.getLogger(__name__).addHandler(logging.NullHandler()) logging.getLogger(__name__).addHandler(logging.NullHandler())
ASYNC_QUEUE = Queue()
ASYNC_THREADS = set()
""":type: set[Thread]""" """:type: set[Thread]"""
ASYNC_LOCK = Lock() # guards ASYNC_THREADS
DEFAULT_GROUP = 0 DEFAULT_GROUP = 0
def _pooled():
"""
A wrapper to run a thread in a thread pool
"""
while 1:
promise = ASYNC_QUEUE.get()
# If unpacking fails, the thread pool is being closed from Updater._join_async_threads
if not isinstance(promise, Promise):
logging.getLogger(__name__).debug("Closing run_async thread %s/%d" %
(current_thread().getName(), len(ASYNC_THREADS)))
break
try:
promise.run()
except:
logging.getLogger(__name__).exception("run_async function raised exception")
def run_async(func): def run_async(func):
""" """Function decorator that will run the function in a new thread.
Function decorator that will run the function in a new thread.
Using this decorator is only possible when only a single Dispatcher exist in the system.
Args: Args:
func (function): The function to run in the thread. func (function): The function to run in the thread.
async_queue (Queue): The queue of the functions to be executed asynchronously.
Returns: Returns:
function: function:
"""
# TODO: handle exception in async threads """
# set a threading.Event to notify caller thread
@wraps(func) @wraps(func)
def async_func(*args, **kwargs): def async_func(*args, **kwargs):
""" return Dispatcher.get_instance().run_async(func, *args, **kwargs)
A wrapper to run a function in a thread
"""
promise = Promise(func, args, kwargs)
ASYNC_QUEUE.put(promise)
return promise
return async_func return async_func
@ -100,7 +73,12 @@ class Dispatcher(object):
callbacks callbacks
workers (Optional[int]): Number of maximum concurrent worker threads for the ``@run_async`` workers (Optional[int]): Number of maximum concurrent worker threads for the ``@run_async``
decorator decorator
""" """
__singleton_lock = Lock()
__singleton_semaphore = BoundedSemaphore()
__singleton = None
logger = logging.getLogger(__name__)
def __init__(self, bot, update_queue, workers=4, exception_event=None, job_queue=None): def __init__(self, bot, update_queue, workers=4, exception_event=None, job_queue=None):
self.bot = bot self.bot = bot
@ -113,28 +91,92 @@ class Dispatcher(object):
""":type: list[int]""" """:type: list[int]"""
self.error_handlers = [] self.error_handlers = []
self.logger = logging.getLogger(__name__)
self.running = False self.running = False
self.__stop_event = Event() self.__stop_event = Event()
self.__exception_event = exception_event or Event() self.__exception_event = exception_event or Event()
self.__async_queue = Queue()
self.__async_threads = set()
with ASYNC_LOCK: # For backward compatibility, we allow a "singleton" mode for the dispatcher. When there's
if not ASYNC_THREADS: # only one instance of Dispatcher, it will be possible to use the `run_async` decorator.
if request.is_con_pool_initialized(): with self.__singleton_lock:
raise RuntimeError('Connection Pool already initialized') if self.__singleton_semaphore.acquire(blocking=0):
self._set_singleton(self)
# we need a connection pool the size of:
# * for each of the workers
# * 1 for Dispatcher
# * 1 for polling Updater (even if updater is webhook, we can spare a connection)
# * 1 for JobQueue
request.CON_POOL_SIZE = workers + 3
for i in range(workers):
thread = Thread(target=_pooled, name=str(i))
ASYNC_THREADS.add(thread)
thread.start()
else: else:
self.logger.debug('Thread pool already initialized, skipping.') self._set_singleton(None)
self._init_async_threads(uuid4(), workers)
@classmethod
def _reset_singleton(cls):
# NOTE: This method was added mainly for test_updater benefit and specifically pypy. Never
# call it in production code.
cls.__singleton_semaphore.release()
def _init_async_threads(self, base_name, workers):
base_name = '{}_'.format(base_name) if base_name else ''
for i in range(workers):
thread = Thread(target=self._pooled, name='{}{}'.format(base_name, i))
self.__async_threads.add(thread)
thread.start()
@classmethod
def _set_singleton(cls, val):
cls.logger.debug('Setting singleton dispatcher as %s', val)
cls.__singleton = weakref.ref(val) if val else None
@classmethod
def get_instance(cls):
"""Get the singleton instance of this class.
Returns:
Dispatcher
"""
if cls.__singleton is not None:
return cls.__singleton()
else:
raise RuntimeError('{} not initialized or multiple instances exist'.format(
cls.__name__))
def _pooled(self):
"""
A wrapper to run a thread in a thread pool
"""
thr_name = current_thread().getName()
while 1:
promise = self.__async_queue.get()
# If unpacking fails, the thread pool is being closed from Updater._join_async_threads
if not isinstance(promise, Promise):
self.logger.debug("Closing run_async thread %s/%d", thr_name,
len(self.__async_threads))
break
try:
promise.run()
except:
self.logger.exception("run_async function raised exception")
def run_async(self, func, *args, **kwargs):
"""Queue a function (with given args/kwargs) to be run asynchronously.
Args:
func (function): The function to run in the thread.
args (Optional[tuple]): Arguments to `func`.
kwargs (Optional[dict]): Keyword arguments to `func`.
Returns:
Promise
"""
# TODO: handle exception in async threads
# set a threading.Event to notify caller thread
promise = Promise(func, args, kwargs)
self.__async_queue.put(promise)
return promise
def start(self): def start(self):
""" """
@ -183,6 +225,25 @@ class Dispatcher(object):
sleep(0.1) sleep(0.1)
self.__stop_event.clear() self.__stop_event.clear()
# async threads must be join()ed only after the dispatcher thread was joined,
# otherwise we can still have new async threads dispatched
threads = list(self.__async_threads)
total = len(threads)
# Stop all threads in the thread pool by put()ting one non-tuple per thread
for i in range(total):
self.__async_queue.put(None)
for i, thr in enumerate(threads):
self.logger.debug('Waiting for async thread {0}/{1} to end'.format(i + 1, total))
thr.join()
self.__async_threads.remove(thr)
self.logger.debug('async thread {0}/{1} has ended'.format(i + 1, total))
@property
def has_running_threads(self):
return self.running or bool(self.__async_threads)
def process_update(self, update): def process_update(self, update):
""" """
Processes a single update. Processes a single update.

View file

@ -29,8 +29,9 @@ from signal import signal, SIGINT, SIGTERM, SIGABRT
from queue import Queue from queue import Queue
from telegram import Bot, TelegramError from telegram import Bot, TelegramError
from telegram.ext import dispatcher, Dispatcher, JobQueue from telegram.ext import Dispatcher, JobQueue
from telegram.error import Unauthorized, InvalidToken from telegram.error import Unauthorized, InvalidToken
from telegram.utils.request import Request
from telegram.utils.webhookhandler import (WebhookServer, WebhookHandler) from telegram.utils.webhookhandler import (WebhookServer, WebhookHandler)
logging.getLogger(__name__).addHandler(logging.NullHandler()) logging.getLogger(__name__).addHandler(logging.NullHandler())
@ -57,13 +58,17 @@ class Updater(object):
base_url (Optional[str]): base_url (Optional[str]):
workers (Optional[int]): Amount of threads in the thread pool for workers (Optional[int]): Amount of threads in the thread pool for
functions decorated with @run_async functions decorated with @run_async
bot (Optional[Bot]): bot (Optional[Bot]): A pre-initialized bot instance. If a pre-initizlied bot is used, it is
the user's responsibility to create it using a `Request` instance with a large enough
connection pool.
job_queue_tick_interval(Optional[float]): The interval the queue should job_queue_tick_interval(Optional[float]): The interval the queue should
be checked for new tasks. Defaults to 1.0 be checked for new tasks. Defaults to 1.0
Raises: Raises:
ValueError: If both `token` and `bot` are passed or none of them. ValueError: If both `token` and `bot` are passed or none of them.
""" """
_request = None
def __init__(self, token=None, base_url=None, workers=4, bot=None): def __init__(self, token=None, base_url=None, workers=4, bot=None):
if (token is None) and (bot is None): if (token is None) and (bot is None):
@ -74,7 +79,14 @@ class Updater(object):
if bot is not None: if bot is not None:
self.bot = bot self.bot = bot
else: else:
self.bot = Bot(token, base_url) # we need a connection pool the size of:
# * for each of the workers
# * 1 for Dispatcher
# * 1 for polling Updater (even if webhook is used, we can spare a connection)
# * 1 for JobQueue
# * 1 for main thread
self._request = Request(con_pool_size=workers + 4)
self.bot = Bot(token, base_url, request=self._request)
self.update_queue = Queue() self.update_queue = Queue()
self.job_queue = JobQueue(self.bot) self.job_queue = JobQueue(self.bot)
self.__exception_event = Event() self.__exception_event = Event()
@ -344,7 +356,7 @@ class Updater(object):
self.job_queue.stop() self.job_queue.stop()
with self.__lock: with self.__lock:
if self.running or dispatcher.ASYNC_THREADS: if self.running or self.dispatcher.has_running_threads:
self.logger.debug('Stopping Updater and Dispatcher...') self.logger.debug('Stopping Updater and Dispatcher...')
self.running = False self.running = False
@ -352,9 +364,10 @@ class Updater(object):
self._stop_httpd() self._stop_httpd()
self._stop_dispatcher() self._stop_dispatcher()
self._join_threads() self._join_threads()
# async threads must be join()ed only after the dispatcher thread was joined,
# otherwise we can still have new async threads dispatched # Stop the Request instance only if it was created by the Updater
self._join_async_threads() if self._request:
self._request.stop()
def _stop_httpd(self): def _stop_httpd(self):
if self.httpd: if self.httpd:
@ -368,21 +381,6 @@ class Updater(object):
self.logger.debug('Requesting Dispatcher to stop...') self.logger.debug('Requesting Dispatcher to stop...')
self.dispatcher.stop() self.dispatcher.stop()
def _join_async_threads(self):
with dispatcher.ASYNC_LOCK:
threads = list(dispatcher.ASYNC_THREADS)
total = len(threads)
# Stop all threads in the thread pool by put()ting one non-tuple per thread
for i in range(total):
dispatcher.ASYNC_QUEUE.put(None)
for i, thr in enumerate(threads):
self.logger.debug('Waiting for async thread {0}/{1} to end'.format(i + 1, total))
thr.join()
dispatcher.ASYNC_THREADS.remove(thr)
self.logger.debug('async thread {0}/{1} has ended'.format(i + 1, total))
def _join_threads(self): def _join_threads(self):
for thr in self.__threads: for thr in self.__threads:
self.logger.debug('Waiting for {0} thread to end'.format(thr.name)) self.logger.debug('Waiting for {0} thread to end'.format(thr.name))

View file

@ -21,7 +21,6 @@
from os.path import basename from os.path import basename
from telegram import TelegramObject from telegram import TelegramObject
from telegram.utils.request import download as _download
class File(TelegramObject): class File(TelegramObject):
@ -34,38 +33,44 @@ class File(TelegramObject):
Args: Args:
file_id (str): file_id (str):
request (telegram.utils.request.Request):
**kwargs: Arbitrary keyword arguments. **kwargs: Arbitrary keyword arguments.
Keyword Args: Keyword Args:
file_size (Optional[int]): file_size (Optional[int]):
file_path (Optional[str]): file_path (Optional[str]):
""" """
def __init__(self, file_id, **kwargs): def __init__(self, file_id, request, **kwargs):
# Required # Required
self.file_id = str(file_id) self.file_id = str(file_id)
self._request = request
# Optionals # Optionals
self.file_size = int(kwargs.get('file_size', 0)) self.file_size = int(kwargs.get('file_size', 0))
self.file_path = str(kwargs.get('file_path', '')) self.file_path = str(kwargs.get('file_path', ''))
@staticmethod @staticmethod
def de_json(data): def de_json(data, request):
""" """
Args: Args:
data (str): data (dict):
request (telegram.utils.request.Request):
Returns: Returns:
telegram.File: telegram.File:
""" """
if not data: if not data:
return None return None
return File(**data) return File(request=request, **data)
def download(self, custom_path=None): def download(self, custom_path=None):
""" """
Args: Args:
custom_path (str): custom_path (str):
""" """
url = self.file_path url = self.file_path
@ -74,4 +79,4 @@ class File(TelegramObject):
else: else:
filename = basename(url) filename = basename(url)
_download(url, filename) self._request.download(url, filename)

View file

@ -33,232 +33,185 @@ from urllib3.connection import HTTPConnection
from telegram import (InputFile, TelegramError) from telegram import (InputFile, TelegramError)
from telegram.error import Unauthorized, NetworkError, TimedOut, BadRequest, ChatMigrated from telegram.error import Unauthorized, NetworkError, TimedOut, BadRequest, ChatMigrated
_CON_POOL = None
""":type: urllib3.PoolManager"""
_CON_POOL_PROXY = None
_CON_POOL_PROXY_KWARGS = {}
CON_POOL_SIZE = 1
logging.getLogger('urllib3').setLevel(logging.WARNING) logging.getLogger('urllib3').setLevel(logging.WARNING)
def _get_con_pool(): class Request(object):
if _CON_POOL is not None: """
return _CON_POOL Helper class for python-telegram-bot which provides methods to perform POST & GET towards
telegram servers.
_init_con_pool()
return _CON_POOL
def _init_con_pool():
global _CON_POOL
kwargs = dict(
maxsize=CON_POOL_SIZE,
cert_reqs='CERT_REQUIRED',
ca_certs=certifi.where(),
socket_options=HTTPConnection.default_socket_options + [
(socket.SOL_SOCKET, socket.SO_KEEPALIVE, 1),
])
proxy_url = _get_con_pool_proxy()
if not proxy_url:
mgr = urllib3.PoolManager(**kwargs)
else:
if _CON_POOL_PROXY_KWARGS:
kwargs.update(_CON_POOL_PROXY_KWARGS)
mgr = urllib3.proxy_from_url(proxy_url, **kwargs)
if mgr.proxy.auth:
# TODO: what about other auth types?
auth_hdrs = urllib3.make_headers(proxy_basic_auth=mgr.proxy.auth)
mgr.proxy_headers.update(auth_hdrs)
_CON_POOL = mgr
def is_con_pool_initialized():
return _CON_POOL is not None
def stop_con_pool():
global _CON_POOL
if _CON_POOL is not None:
_CON_POOL.clear()
_CON_POOL = None
def set_con_pool_proxy(url, **urllib3_kwargs):
"""Setup connection pool behind a proxy
Args: Args:
url (str): The URL to the proxy server. For example: `http://127.0.0.1:3128` proxy_url (str): The URL to the proxy server. For example: `http://127.0.0.1:3128`.
urllib3_kwargs (dict): Arbitrary arguments passed as-is to `urllib3.ProxyManager` urllib3_proxy_kwargs (dict): Arbitrary arguments passed as-is to `urllib3.ProxyManager`.
This value will be ignored if proxy_url is not set.
"""
global _CON_POOL_PROXY
global _CON_POOL_PROXY_KWARGS
if is_con_pool_initialized():
raise TelegramError('conpool already initialized')
_CON_POOL_PROXY = url
_CON_POOL_PROXY_KWARGS = urllib3_kwargs
def _get_con_pool_proxy():
"""Return the user configured proxy according to the following order:
* proxy configured using `set_con_pool_proxy()`.
* proxy set in `HTTPS_PROXY` env. var.
* proxy set in `https_proxy` env. var.
* None (if no proxy is configured)
Returns:
str | None
"""
if _CON_POOL_PROXY:
return _CON_POOL_PROXY
from_env = os.environ.get('HTTPS_PROXY')
if from_env:
return from_env
from_env = os.environ.get('https_proxy')
if from_env:
return from_env
return None
def _parse(json_data):
"""Try and parse the JSON returned from Telegram.
Returns:
dict: A JSON parsed as Python dict with results - on error this dict will be empty.
"""
decoded_s = json_data.decode('utf-8')
try:
data = json.loads(decoded_s)
except ValueError:
raise TelegramError('Invalid server response')
if not data.get('ok'):
description = data.get('description')
parameters = data.get('parameters')
if parameters:
migrate_to_chat_id = parameters.get('migrate_to_chat_id')
if migrate_to_chat_id:
raise ChatMigrated(migrate_to_chat_id)
if description:
return description
return data['result']
def _request_wrapper(*args, **kwargs):
"""Wraps urllib3 request for handling known exceptions.
Args:
args: unnamed arguments, passed to urllib3 request.
kwargs: keyword arguments, passed tp urllib3 request.
Returns:
str: A non-parsed JSON text.
Raises:
TelegramError
""" """
try: def __init__(self, con_pool_size=1, proxy_url=None, urllib3_proxy_kwargs=None):
resp = _get_con_pool().request(*args, **kwargs) if urllib3_proxy_kwargs is None:
except urllib3.exceptions.TimeoutError as error: urllib3_proxy_kwargs = dict()
raise TimedOut()
except urllib3.exceptions.HTTPError as error:
# HTTPError must come last as its the base urllib3 exception class
# TODO: do something smart here; for now just raise NetworkError
raise NetworkError('urllib3 HTTPError {0}'.format(error))
if 200 <= resp.status <= 299: kwargs = dict(
# 200-299 range are HTTP success statuses maxsize=con_pool_size,
return resp.data cert_reqs='CERT_REQUIRED',
ca_certs=certifi.where(),
socket_options=HTTPConnection.default_socket_options + [
(socket.SOL_SOCKET, socket.SO_KEEPALIVE, 1),
])
try: # Set a proxy according to the following order:
message = _parse(resp.data) # * proxy defined in proxy_url (+ urllib3_proxy_kwargs)
except ValueError: # * proxy set in `HTTPS_PROXY` env. var.
raise NetworkError('Unknown HTTPError {0}'.format(resp.status)) # * proxy set in `https_proxy` env. var.
# * None (if no proxy is configured)
if resp.status in (401, 403): if not proxy_url:
raise Unauthorized() proxy_url = os.environ.get('HTTPS_PROXY') or os.environ.get('https_proxy')
elif resp.status == 400:
raise BadRequest(repr(message))
elif resp.status == 502:
raise NetworkError('Bad Gateway')
else:
raise NetworkError('{0} ({1})'.format(message, resp.status))
if not proxy_url:
mgr = urllib3.PoolManager(**kwargs)
else:
kwargs.update(urllib3_proxy_kwargs)
mgr = urllib3.proxy_from_url(proxy_url, **kwargs)
if mgr.proxy.auth:
# TODO: what about other auth types?
auth_hdrs = urllib3.make_headers(proxy_basic_auth=mgr.proxy.auth)
mgr.proxy_headers.update(auth_hdrs)
def get(url): self._con_pool = mgr
"""Request an URL.
Args:
url:
The web location we want to retrieve.
Returns: def stop(self):
A JSON object. self._con_pool.clear()
""" @staticmethod
result = _request_wrapper('GET', url) def _parse(json_data):
"""Try and parse the JSON returned from Telegram.
return _parse(result) Returns:
dict: A JSON parsed as Python dict with results - on error this dict will be empty.
"""
decoded_s = json_data.decode('utf-8')
try:
data = json.loads(decoded_s)
except ValueError:
raise TelegramError('Invalid server response')
def post(url, data, timeout=None): if not data.get('ok'):
"""Request an URL. description = data.get('description')
Args: parameters = data.get('parameters')
url: if parameters:
The web location we want to retrieve. migrate_to_chat_id = parameters.get('migrate_to_chat_id')
data: if migrate_to_chat_id:
A dict of (str, unicode) key/value pairs. raise ChatMigrated(migrate_to_chat_id)
timeout: if description:
float. If this value is specified, use it as the definitive timeout (in return description
seconds) for urlopen() operations. [Optional]
Notes: return data['result']
If neither `timeout` nor `data['timeout']` is specified. The underlying
defaults are used.
Returns: def _request_wrapper(self, *args, **kwargs):
A JSON object. """Wraps urllib3 request for handling known exceptions.
""" Args:
urlopen_kwargs = {} args: unnamed arguments, passed to urllib3 request.
kwargs: keyword arguments, passed tp urllib3 request.
if timeout is not None: Returns:
urlopen_kwargs['timeout'] = timeout str: A non-parsed JSON text.
if InputFile.is_inputfile(data): Raises:
data = InputFile(data) TelegramError
result = _request_wrapper('POST', url, body=data.to_form(), headers=data.headers)
else:
data = json.dumps(data)
result = _request_wrapper(
'POST',
url,
body=data.encode(),
headers={'Content-Type': 'application/json'},
**urlopen_kwargs)
return _parse(result) """
try:
resp = self._con_pool.request(*args, **kwargs)
except urllib3.exceptions.TimeoutError:
raise TimedOut()
except urllib3.exceptions.HTTPError as error:
# HTTPError must come last as its the base urllib3 exception class
# TODO: do something smart here; for now just raise NetworkError
raise NetworkError('urllib3 HTTPError {0}'.format(error))
if 200 <= resp.status <= 299:
# 200-299 range are HTTP success statuses
return resp.data
def download(url, filename): try:
"""Download a file by its URL. message = self._parse(resp.data)
Args: except ValueError:
url: raise NetworkError('Unknown HTTPError {0}'.format(resp.status))
The web location we want to retrieve.
filename: if resp.status in (401, 403):
The filename within the path to download the file. raise Unauthorized()
elif resp.status == 400:
raise BadRequest(repr(message))
elif resp.status == 502:
raise NetworkError('Bad Gateway')
else:
raise NetworkError('{0} ({1})'.format(message, resp.status))
""" def get(self, url):
buf = _request_wrapper('GET', url) """Request an URL.
with open(filename, 'wb') as fobj: Args:
fobj.write(buf) url:
The web location we want to retrieve.
Returns:
A JSON object.
"""
result = self._request_wrapper('GET', url)
return self._parse(result)
def post(self, url, data, timeout=None):
"""Request an URL.
Args:
url:
The web location we want to retrieve.
data:
A dict of (str, unicode) key/value pairs.
timeout:
float. If this value is specified, use it as the definitive timeout (in
seconds) for urlopen() operations. [Optional]
Notes:
If neither `timeout` nor `data['timeout']` is specified. The underlying
defaults are used.
Returns:
A JSON object.
"""
urlopen_kwargs = {}
if timeout is not None:
urlopen_kwargs['timeout'] = timeout
if InputFile.is_inputfile(data):
data = InputFile(data)
result = self._request_wrapper('POST', url, body=data.to_form(), headers=data.headers)
else:
data = json.dumps(data)
result = self._request_wrapper(
'POST',
url,
body=data.encode(),
headers={'Content-Type': 'application/json'},
**urlopen_kwargs)
return self._parse(result)
def download(self, url, filename):
"""Download a file by its URL.
Args:
url:
The web location we want to retrieve.
filename:
The filename within the path to download the file.
"""
buf = self._request_wrapper('GET', url)
with open(filename, 'wb') as fobj:
fobj.write(buf)

View file

@ -198,7 +198,9 @@ class BotTest(BaseTest, unittest.TestCase):
def testInvalidSrvResp(self): def testInvalidSrvResp(self):
with self.assertRaisesRegexp(telegram.TelegramError, 'Invalid server response'): with self.assertRaisesRegexp(telegram.TelegramError, 'Invalid server response'):
# bypass the valid token check # bypass the valid token check
bot = telegram.Bot.__new__(telegram.Bot) newbot_cls = type(
'NoTokenValidateBot', (telegram.Bot,), dict(_validate_token=lambda x, y: None))
bot = newbot_cls('0xdeadbeef')
bot.base_url = 'https://api.telegram.org/bot{0}'.format('12') bot.base_url = 'https://api.telegram.org/bot{0}'.format('12')
bot.getMe() bot.getMe()

View file

@ -36,7 +36,6 @@ except ImportError:
sys.path.append('.') sys.path.append('.')
from telegram import Update, Message, TelegramError, User, Chat, Bot from telegram import Update, Message, TelegramError, User, Chat, Bot
from telegram.utils.request import stop_con_pool
from telegram.ext import * from telegram.ext import *
from tests.base import BaseTest from tests.base import BaseTest
from tests.test_updater import MockBot from tests.test_updater import MockBot
@ -61,10 +60,10 @@ class ConversationHandlerTest(BaseTest, unittest.TestCase):
# At first we're thirsty. Then we brew coffee, we drink it # At first we're thirsty. Then we brew coffee, we drink it
# and then we can start coding! # and then we can start coding!
END, THIRSTY, BREWING, DRINKING, CODING = range(-1, 4) END, THIRSTY, BREWING, DRINKING, CODING = range(-1, 4)
_updater = None
# Test related # Test related
def setUp(self): def setUp(self):
self.updater = None
self.current_state = dict() self.current_state = dict()
self.entry_points = [CommandHandler('start', self.start)] self.entry_points = [CommandHandler('start', self.start)]
self.states = {self.THIRSTY: [CommandHandler('brew', self.brew), self.states = {self.THIRSTY: [CommandHandler('brew', self.brew),
@ -78,14 +77,22 @@ class ConversationHandlerTest(BaseTest, unittest.TestCase):
self.fallbacks = [CommandHandler('eat', self.start)] self.fallbacks = [CommandHandler('eat', self.start)]
def _setup_updater(self, *args, **kwargs): def _setup_updater(self, *args, **kwargs):
stop_con_pool()
bot = MockBot(*args, **kwargs) bot = MockBot(*args, **kwargs)
self.updater = Updater(workers=2, bot=bot) self.updater = Updater(workers=2, bot=bot)
def tearDown(self): def tearDown(self):
if self.updater is not None: if self.updater is not None:
self.updater.stop() self.updater.stop()
stop_con_pool()
@property
def updater(self):
return self._updater
@updater.setter
def updater(self, val):
if self._updater:
self._updater.stop()
self._updater = val
def reset(self): def reset(self):
self.current_state = dict() self.current_state = dict()

View file

@ -101,19 +101,19 @@ class FileTest(BaseTest, unittest.TestCase):
self.assertTrue(os.path.isfile('telegram.ogg')) self.assertTrue(os.path.isfile('telegram.ogg'))
def test_file_de_json(self): def test_file_de_json(self):
newFile = telegram.File.de_json(self.json_dict) newFile = telegram.File.de_json(self.json_dict, None)
self.assertEqual(newFile.file_id, self.json_dict['file_id']) self.assertEqual(newFile.file_id, self.json_dict['file_id'])
self.assertEqual(newFile.file_path, self.json_dict['file_path']) self.assertEqual(newFile.file_path, self.json_dict['file_path'])
self.assertEqual(newFile.file_size, self.json_dict['file_size']) self.assertEqual(newFile.file_size, self.json_dict['file_size'])
def test_file_to_json(self): def test_file_to_json(self):
newFile = telegram.File.de_json(self.json_dict) newFile = telegram.File.de_json(self.json_dict, None)
self.assertTrue(self.is_json(newFile.to_json())) self.assertTrue(self.is_json(newFile.to_json()))
def test_file_to_dict(self): def test_file_to_dict(self):
newFile = telegram.File.de_json(self.json_dict) newFile = telegram.File.de_json(self.json_dict, None)
self.assertTrue(self.is_dict(newFile.to_dict())) self.assertTrue(self.is_dict(newFile.to_dict()))
self.assertEqual(newFile['file_id'], self.json_dict['file_id']) self.assertEqual(newFile['file_id'], self.json_dict['file_id'])

View file

@ -25,9 +25,10 @@ import sys
import unittest import unittest
from time import sleep from time import sleep
from tests.test_updater import MockBot
sys.path.append('.') sys.path.append('.')
from telegram.utils.request import stop_con_pool
from telegram.ext import JobQueue, Job, Updater from telegram.ext import JobQueue, Job, Updater
from tests.base import BaseTest from tests.base import BaseTest
@ -49,13 +50,12 @@ class JobQueueTest(BaseTest, unittest.TestCase):
""" """
def setUp(self): def setUp(self):
self.jq = JobQueue("Bot") self.jq = JobQueue(MockBot('jobqueue_test'))
self.result = 0 self.result = 0
def tearDown(self): def tearDown(self):
if self.jq is not None: if self.jq is not None:
self.jq.stop() self.jq.stop()
stop_con_pool()
def job1(self, bot, job): def job1(self, bot, job):
self.result += 1 self.result += 1
@ -158,12 +158,15 @@ class JobQueueTest(BaseTest, unittest.TestCase):
def test_inUpdater(self): def test_inUpdater(self):
u = Updater(bot="MockBot") u = Updater(bot="MockBot")
u.job_queue.put(Job(self.job1, 0.5)) try:
sleep(0.75) u.job_queue.put(Job(self.job1, 0.5))
self.assertEqual(1, self.result) sleep(0.75)
u.stop() self.assertEqual(1, self.result)
sleep(2) u.stop()
self.assertEqual(1, self.result) sleep(2)
self.assertEqual(1, self.result)
finally:
u.stop()
if __name__ == '__main__': if __name__ == '__main__':

View file

@ -29,10 +29,13 @@ import re
import unittest import unittest
from datetime import datetime from datetime import datetime
from time import sleep from time import sleep
from queue import Queue
from random import randrange from random import randrange
from future.builtins import bytes from future.builtins import bytes
from telegram.utils.request import Request as Requester
try: try:
# python2 # python2
from urllib2 import urlopen, Request, HTTPError from urllib2 import urlopen, Request, HTTPError
@ -44,12 +47,11 @@ except ImportError:
sys.path.append('.') sys.path.append('.')
from telegram import Update, Message, TelegramError, User, Chat, Bot, InlineQuery, CallbackQuery from telegram import Update, Message, TelegramError, User, Chat, Bot, InlineQuery, CallbackQuery
from telegram.utils.request import stop_con_pool
from telegram.ext import * from telegram.ext import *
from telegram.ext.dispatcher import run_async from telegram.ext.dispatcher import run_async
from telegram.error import Unauthorized, InvalidToken from telegram.error import Unauthorized, InvalidToken
from tests.base import BaseTest from tests.base import BaseTest
from threading import Lock, Thread from threading import Lock, Thread, current_thread, Semaphore
# Enable logging # Enable logging
root = logging.getLogger() root = logging.getLogger()
@ -68,10 +70,8 @@ class UpdaterTest(BaseTest, unittest.TestCase):
WebhookHandler WebhookHandler
""" """
updater = None _updater = None
received_message = None received_message = None
message_count = None
lock = None
def setUp(self): def setUp(self):
self.updater = None self.updater = None
@ -79,15 +79,25 @@ class UpdaterTest(BaseTest, unittest.TestCase):
self.message_count = 0 self.message_count = 0
self.lock = Lock() self.lock = Lock()
@property
def updater(self):
return self._updater
@updater.setter
def updater(self, val):
if self._updater:
self._updater.stop()
self._updater.dispatcher._reset_singleton()
del self._updater.dispatcher
self._updater = val
def _setup_updater(self, *args, **kwargs): def _setup_updater(self, *args, **kwargs):
stop_con_pool()
bot = MockBot(*args, **kwargs) bot = MockBot(*args, **kwargs)
self.updater = Updater(workers=2, bot=bot) self.updater = Updater(workers=2, bot=bot)
def tearDown(self): def tearDown(self):
if self.updater is not None: self.updater = None
self.updater.stop()
stop_con_pool()
def reset(self): def reset(self):
self.message_count = 0 self.message_count = 0
@ -411,6 +421,51 @@ class UpdaterTest(BaseTest, unittest.TestCase):
self.assertEqual(self.received_message, 'Test5') self.assertEqual(self.received_message, 'Test5')
self.assertEqual(self.message_count, 2) self.assertEqual(self.message_count, 2)
def test_multiple_dispatchers(self):
def get_dispatcher_name(q):
q.put(current_thread().name)
sleep(1.2)
d1 = Dispatcher(MockBot('disp1'), Queue(), workers=1)
d2 = Dispatcher(MockBot('disp2'), Queue(), workers=1)
q1 = Queue()
q2 = Queue()
try:
d1.run_async(get_dispatcher_name, q1)
d2.run_async(get_dispatcher_name, q2)
name1 = q1.get()
name2 = q2.get()
self.assertNotEqual(name1, name2)
finally:
d1.stop()
d2.stop()
# following three lines are for pypy unitests
d1._reset_singleton()
del d1
del d2
def test_multiple_dispatcers_no_decorator(self):
@run_async
def must_raise_runtime_error():
pass
d1 = Dispatcher(MockBot('disp1'), Queue(), workers=1)
d2 = Dispatcher(MockBot('disp2'), Queue(), workers=1)
self.assertRaises(RuntimeError, must_raise_runtime_error)
d1.stop()
d2.stop()
# following three lines are for pypy unitests
d1._reset_singleton()
del d1
del d2
def test_additionalArgs(self): def test_additionalArgs(self):
self._setup_updater('', messages=0) self._setup_updater('', messages=0)
handler = StringCommandHandler( handler = StringCommandHandler(