Reusable dispatcher (#402)

* Create a Request class which maintains its own connection pool
* When creating a Bot instance a new Request instance will be created if one wasn't supplied.
* Updater is responsible for creating a Request instance if a Bot instance wasn't provided.
* Dispatcher: add method to run async functions without decorator
* Dispatcher can now run as a singleton (allowing run_async decorator to work) as it always did and as multiple instances (where run_async decorator will raise RuntimeError)
This commit is contained in:
Noam Meltzer 2016-09-06 16:38:07 +03:00 committed by GitHub
parent ca81a75f29
commit e4a132c0e4
11 changed files with 420 additions and 331 deletions

View file

@ -1,15 +1,15 @@
- repo: git://github.com/pre-commit/mirrors-yapf
sha: 34303f2856d4e4ba26dc302d9c28632e9b5a8626
sha: v0.11.0
hooks:
- id: yapf
files: ^(telegram|tests)/.*\.py$
- repo: git://github.com/pre-commit/pre-commit-hooks
sha: 3fa02652357ff0dbb42b5bc78c673b7bc105fcf3
sha: 18d7035de5388cc7775be57f529c154bf541aab9
hooks:
- id: flake8
files: ^telegram/.*\.py$
- repo: git://github.com/pre-commit/mirrors-pylint
sha: 4de6c8dfadef1a271a814561ce05b8bc1c446d22
sha: v1.5.5
hooks:
- id: pylint
files: ^telegram/.*\.py$

View file

@ -25,7 +25,7 @@ import logging
from telegram import (User, Message, Update, Chat, ChatMember, UserProfilePhotos, File,
ReplyMarkup, TelegramObject)
from telegram.error import InvalidToken
from telegram.utils import request
from telegram.utils.request import Request
logging.getLogger(__name__).addHandler(logging.NullHandler())
@ -44,10 +44,11 @@ class Bot(TelegramObject):
token (str): Bot's unique authentication.
base_url (Optional[str]): Telegram Bot API service URL.
base_file_url (Optional[str]): Telegram Bot API file URL.
request (Optional[Request]): Pre initialized `Request` class.
"""
def __init__(self, token, base_url=None, base_file_url=None):
def __init__(self, token, base_url=None, base_file_url=None, request=None):
self.token = self._validate_token(token)
if not base_url:
@ -61,9 +62,13 @@ class Bot(TelegramObject):
self.base_file_url = base_file_url + self.token
self.bot = None
self._request = request or Request()
self.logger = logging.getLogger(__name__)
@property
def request(self):
return self._request
@staticmethod
def _validate_token(token):
"""a very basic validation on token"""
@ -144,7 +149,7 @@ class Bot(TelegramObject):
else:
data['reply_markup'] = reply_markup
result = request.post(url, data, timeout=kwargs.get('timeout'))
result = self._request.post(url, data, timeout=kwargs.get('timeout'))
if result is True:
return result
@ -169,7 +174,7 @@ class Bot(TelegramObject):
url = '{0}/getMe'.format(self.base_url)
result = request.get(url)
result = self._request.get(url)
self.bot = User.de_json(result)
@ -813,7 +818,7 @@ class Bot(TelegramObject):
if switch_pm_parameter:
data['switch_pm_parameter'] = switch_pm_parameter
result = request.post(url, data, timeout=kwargs.get('timeout'))
result = self._request.post(url, data, timeout=kwargs.get('timeout'))
return result
@ -853,7 +858,7 @@ class Bot(TelegramObject):
if limit:
data['limit'] = limit
result = request.post(url, data, timeout=kwargs.get('timeout'))
result = self._request.post(url, data, timeout=kwargs.get('timeout'))
return UserProfilePhotos.de_json(result)
@ -884,12 +889,12 @@ class Bot(TelegramObject):
data = {'file_id': file_id}
result = request.post(url, data, timeout=kwargs.get('timeout'))
result = self._request.post(url, data, timeout=kwargs.get('timeout'))
if result.get('file_path'):
result['file_path'] = '%s/%s' % (self.base_file_url, result['file_path'])
return File.de_json(result)
return File.de_json(result, self._request)
@log
def kickChatMember(self, chat_id, user_id, **kwargs):
@ -921,7 +926,7 @@ class Bot(TelegramObject):
data = {'chat_id': chat_id, 'user_id': user_id}
result = request.post(url, data, timeout=kwargs.get('timeout'))
result = self._request.post(url, data, timeout=kwargs.get('timeout'))
return result
@ -955,7 +960,7 @@ class Bot(TelegramObject):
data = {'chat_id': chat_id, 'user_id': user_id}
result = request.post(url, data, timeout=kwargs.get('timeout'))
result = self._request.post(url, data, timeout=kwargs.get('timeout'))
return result
@ -999,7 +1004,7 @@ class Bot(TelegramObject):
if show_alert:
data['show_alert'] = show_alert
result = request.post(url, data, timeout=kwargs.get('timeout'))
result = self._request.post(url, data, timeout=kwargs.get('timeout'))
return result
@ -1213,7 +1218,7 @@ class Bot(TelegramObject):
urlopen_timeout = timeout + network_delay
result = request.post(url, data, timeout=urlopen_timeout)
result = self._request.post(url, data, timeout=urlopen_timeout)
if result:
self.logger.debug('Getting updates: %s', [u['update_id'] for u in result])
@ -1256,7 +1261,7 @@ class Bot(TelegramObject):
if certificate:
data['certificate'] = certificate
result = request.post(url, data, timeout=kwargs.get('timeout'))
result = self._request.post(url, data, timeout=kwargs.get('timeout'))
return result
@ -1286,7 +1291,7 @@ class Bot(TelegramObject):
data = {'chat_id': chat_id}
result = request.post(url, data, timeout=kwargs.get('timeout'))
result = self._request.post(url, data, timeout=kwargs.get('timeout'))
return result
@ -1318,7 +1323,7 @@ class Bot(TelegramObject):
data = {'chat_id': chat_id}
result = request.post(url, data, timeout=kwargs.get('timeout'))
result = self._request.post(url, data, timeout=kwargs.get('timeout'))
return Chat.de_json(result)
@ -1353,7 +1358,7 @@ class Bot(TelegramObject):
data = {'chat_id': chat_id}
result = request.post(url, data, timeout=kwargs.get('timeout'))
result = self._request.post(url, data, timeout=kwargs.get('timeout'))
return [ChatMember.de_json(x) for x in result]
@ -1383,7 +1388,7 @@ class Bot(TelegramObject):
data = {'chat_id': chat_id}
result = request.post(url, data, timeout=kwargs.get('timeout'))
result = self._request.post(url, data, timeout=kwargs.get('timeout'))
return result
@ -1416,7 +1421,7 @@ class Bot(TelegramObject):
data = {'chat_id': chat_id, 'user_id': user_id}
result = request.post(url, data, timeout=kwargs.get('timeout'))
result = self._request.post(url, data, timeout=kwargs.get('timeout'))
return ChatMember.de_json(result)

View file

@ -19,70 +19,43 @@
"""This module contains the Dispatcher class."""
import logging
import weakref
from functools import wraps
from threading import Thread, Lock, Event, current_thread
from threading import Thread, Lock, Event, current_thread, BoundedSemaphore
from time import sleep
from uuid import uuid4
from queue import Queue, Empty
from future.builtins import range
from telegram import TelegramError
from telegram.utils import request
from telegram.ext.handler import Handler
from telegram.utils.deprecate import deprecate
from telegram.utils.promise import Promise
logging.getLogger(__name__).addHandler(logging.NullHandler())
ASYNC_QUEUE = Queue()
ASYNC_THREADS = set()
""":type: set[Thread]"""
ASYNC_LOCK = Lock() # guards ASYNC_THREADS
DEFAULT_GROUP = 0
def _pooled():
"""
A wrapper to run a thread in a thread pool
"""
while 1:
promise = ASYNC_QUEUE.get()
# If unpacking fails, the thread pool is being closed from Updater._join_async_threads
if not isinstance(promise, Promise):
logging.getLogger(__name__).debug("Closing run_async thread %s/%d" %
(current_thread().getName(), len(ASYNC_THREADS)))
break
try:
promise.run()
except:
logging.getLogger(__name__).exception("run_async function raised exception")
def run_async(func):
"""
Function decorator that will run the function in a new thread.
"""Function decorator that will run the function in a new thread.
Using this decorator is only possible when only a single Dispatcher exist in the system.
Args:
func (function): The function to run in the thread.
async_queue (Queue): The queue of the functions to be executed asynchronously.
Returns:
function:
"""
# TODO: handle exception in async threads
# set a threading.Event to notify caller thread
"""
@wraps(func)
def async_func(*args, **kwargs):
"""
A wrapper to run a function in a thread
"""
promise = Promise(func, args, kwargs)
ASYNC_QUEUE.put(promise)
return promise
return Dispatcher.get_instance().run_async(func, *args, **kwargs)
return async_func
@ -100,7 +73,12 @@ class Dispatcher(object):
callbacks
workers (Optional[int]): Number of maximum concurrent worker threads for the ``@run_async``
decorator
"""
__singleton_lock = Lock()
__singleton_semaphore = BoundedSemaphore()
__singleton = None
logger = logging.getLogger(__name__)
def __init__(self, bot, update_queue, workers=4, exception_event=None, job_queue=None):
self.bot = bot
@ -113,28 +91,92 @@ class Dispatcher(object):
""":type: list[int]"""
self.error_handlers = []
self.logger = logging.getLogger(__name__)
self.running = False
self.__stop_event = Event()
self.__exception_event = exception_event or Event()
self.__async_queue = Queue()
self.__async_threads = set()
with ASYNC_LOCK:
if not ASYNC_THREADS:
if request.is_con_pool_initialized():
raise RuntimeError('Connection Pool already initialized')
# we need a connection pool the size of:
# * for each of the workers
# * 1 for Dispatcher
# * 1 for polling Updater (even if updater is webhook, we can spare a connection)
# * 1 for JobQueue
request.CON_POOL_SIZE = workers + 3
for i in range(workers):
thread = Thread(target=_pooled, name=str(i))
ASYNC_THREADS.add(thread)
thread.start()
# For backward compatibility, we allow a "singleton" mode for the dispatcher. When there's
# only one instance of Dispatcher, it will be possible to use the `run_async` decorator.
with self.__singleton_lock:
if self.__singleton_semaphore.acquire(blocking=0):
self._set_singleton(self)
else:
self.logger.debug('Thread pool already initialized, skipping.')
self._set_singleton(None)
self._init_async_threads(uuid4(), workers)
@classmethod
def _reset_singleton(cls):
# NOTE: This method was added mainly for test_updater benefit and specifically pypy. Never
# call it in production code.
cls.__singleton_semaphore.release()
def _init_async_threads(self, base_name, workers):
base_name = '{}_'.format(base_name) if base_name else ''
for i in range(workers):
thread = Thread(target=self._pooled, name='{}{}'.format(base_name, i))
self.__async_threads.add(thread)
thread.start()
@classmethod
def _set_singleton(cls, val):
cls.logger.debug('Setting singleton dispatcher as %s', val)
cls.__singleton = weakref.ref(val) if val else None
@classmethod
def get_instance(cls):
"""Get the singleton instance of this class.
Returns:
Dispatcher
"""
if cls.__singleton is not None:
return cls.__singleton()
else:
raise RuntimeError('{} not initialized or multiple instances exist'.format(
cls.__name__))
def _pooled(self):
"""
A wrapper to run a thread in a thread pool
"""
thr_name = current_thread().getName()
while 1:
promise = self.__async_queue.get()
# If unpacking fails, the thread pool is being closed from Updater._join_async_threads
if not isinstance(promise, Promise):
self.logger.debug("Closing run_async thread %s/%d", thr_name,
len(self.__async_threads))
break
try:
promise.run()
except:
self.logger.exception("run_async function raised exception")
def run_async(self, func, *args, **kwargs):
"""Queue a function (with given args/kwargs) to be run asynchronously.
Args:
func (function): The function to run in the thread.
args (Optional[tuple]): Arguments to `func`.
kwargs (Optional[dict]): Keyword arguments to `func`.
Returns:
Promise
"""
# TODO: handle exception in async threads
# set a threading.Event to notify caller thread
promise = Promise(func, args, kwargs)
self.__async_queue.put(promise)
return promise
def start(self):
"""
@ -183,6 +225,25 @@ class Dispatcher(object):
sleep(0.1)
self.__stop_event.clear()
# async threads must be join()ed only after the dispatcher thread was joined,
# otherwise we can still have new async threads dispatched
threads = list(self.__async_threads)
total = len(threads)
# Stop all threads in the thread pool by put()ting one non-tuple per thread
for i in range(total):
self.__async_queue.put(None)
for i, thr in enumerate(threads):
self.logger.debug('Waiting for async thread {0}/{1} to end'.format(i + 1, total))
thr.join()
self.__async_threads.remove(thr)
self.logger.debug('async thread {0}/{1} has ended'.format(i + 1, total))
@property
def has_running_threads(self):
return self.running or bool(self.__async_threads)
def process_update(self, update):
"""
Processes a single update.

View file

@ -29,8 +29,9 @@ from signal import signal, SIGINT, SIGTERM, SIGABRT
from queue import Queue
from telegram import Bot, TelegramError
from telegram.ext import dispatcher, Dispatcher, JobQueue
from telegram.ext import Dispatcher, JobQueue
from telegram.error import Unauthorized, InvalidToken
from telegram.utils.request import Request
from telegram.utils.webhookhandler import (WebhookServer, WebhookHandler)
logging.getLogger(__name__).addHandler(logging.NullHandler())
@ -57,13 +58,17 @@ class Updater(object):
base_url (Optional[str]):
workers (Optional[int]): Amount of threads in the thread pool for
functions decorated with @run_async
bot (Optional[Bot]):
bot (Optional[Bot]): A pre-initialized bot instance. If a pre-initizlied bot is used, it is
the user's responsibility to create it using a `Request` instance with a large enough
connection pool.
job_queue_tick_interval(Optional[float]): The interval the queue should
be checked for new tasks. Defaults to 1.0
Raises:
ValueError: If both `token` and `bot` are passed or none of them.
"""
_request = None
def __init__(self, token=None, base_url=None, workers=4, bot=None):
if (token is None) and (bot is None):
@ -74,7 +79,14 @@ class Updater(object):
if bot is not None:
self.bot = bot
else:
self.bot = Bot(token, base_url)
# we need a connection pool the size of:
# * for each of the workers
# * 1 for Dispatcher
# * 1 for polling Updater (even if webhook is used, we can spare a connection)
# * 1 for JobQueue
# * 1 for main thread
self._request = Request(con_pool_size=workers + 4)
self.bot = Bot(token, base_url, request=self._request)
self.update_queue = Queue()
self.job_queue = JobQueue(self.bot)
self.__exception_event = Event()
@ -344,7 +356,7 @@ class Updater(object):
self.job_queue.stop()
with self.__lock:
if self.running or dispatcher.ASYNC_THREADS:
if self.running or self.dispatcher.has_running_threads:
self.logger.debug('Stopping Updater and Dispatcher...')
self.running = False
@ -352,9 +364,10 @@ class Updater(object):
self._stop_httpd()
self._stop_dispatcher()
self._join_threads()
# async threads must be join()ed only after the dispatcher thread was joined,
# otherwise we can still have new async threads dispatched
self._join_async_threads()
# Stop the Request instance only if it was created by the Updater
if self._request:
self._request.stop()
def _stop_httpd(self):
if self.httpd:
@ -368,21 +381,6 @@ class Updater(object):
self.logger.debug('Requesting Dispatcher to stop...')
self.dispatcher.stop()
def _join_async_threads(self):
with dispatcher.ASYNC_LOCK:
threads = list(dispatcher.ASYNC_THREADS)
total = len(threads)
# Stop all threads in the thread pool by put()ting one non-tuple per thread
for i in range(total):
dispatcher.ASYNC_QUEUE.put(None)
for i, thr in enumerate(threads):
self.logger.debug('Waiting for async thread {0}/{1} to end'.format(i + 1, total))
thr.join()
dispatcher.ASYNC_THREADS.remove(thr)
self.logger.debug('async thread {0}/{1} has ended'.format(i + 1, total))
def _join_threads(self):
for thr in self.__threads:
self.logger.debug('Waiting for {0} thread to end'.format(thr.name))

View file

@ -21,7 +21,6 @@
from os.path import basename
from telegram import TelegramObject
from telegram.utils.request import download as _download
class File(TelegramObject):
@ -34,38 +33,44 @@ class File(TelegramObject):
Args:
file_id (str):
request (telegram.utils.request.Request):
**kwargs: Arbitrary keyword arguments.
Keyword Args:
file_size (Optional[int]):
file_path (Optional[str]):
"""
def __init__(self, file_id, **kwargs):
def __init__(self, file_id, request, **kwargs):
# Required
self.file_id = str(file_id)
self._request = request
# Optionals
self.file_size = int(kwargs.get('file_size', 0))
self.file_path = str(kwargs.get('file_path', ''))
@staticmethod
def de_json(data):
def de_json(data, request):
"""
Args:
data (str):
data (dict):
request (telegram.utils.request.Request):
Returns:
telegram.File:
"""
if not data:
return None
return File(**data)
return File(request=request, **data)
def download(self, custom_path=None):
"""
Args:
custom_path (str):
"""
url = self.file_path
@ -74,4 +79,4 @@ class File(TelegramObject):
else:
filename = basename(url)
_download(url, filename)
self._request.download(url, filename)

View file

@ -33,232 +33,185 @@ from urllib3.connection import HTTPConnection
from telegram import (InputFile, TelegramError)
from telegram.error import Unauthorized, NetworkError, TimedOut, BadRequest, ChatMigrated
_CON_POOL = None
""":type: urllib3.PoolManager"""
_CON_POOL_PROXY = None
_CON_POOL_PROXY_KWARGS = {}
CON_POOL_SIZE = 1
logging.getLogger('urllib3').setLevel(logging.WARNING)
def _get_con_pool():
if _CON_POOL is not None:
return _CON_POOL
_init_con_pool()
return _CON_POOL
def _init_con_pool():
global _CON_POOL
kwargs = dict(
maxsize=CON_POOL_SIZE,
cert_reqs='CERT_REQUIRED',
ca_certs=certifi.where(),
socket_options=HTTPConnection.default_socket_options + [
(socket.SOL_SOCKET, socket.SO_KEEPALIVE, 1),
])
proxy_url = _get_con_pool_proxy()
if not proxy_url:
mgr = urllib3.PoolManager(**kwargs)
else:
if _CON_POOL_PROXY_KWARGS:
kwargs.update(_CON_POOL_PROXY_KWARGS)
mgr = urllib3.proxy_from_url(proxy_url, **kwargs)
if mgr.proxy.auth:
# TODO: what about other auth types?
auth_hdrs = urllib3.make_headers(proxy_basic_auth=mgr.proxy.auth)
mgr.proxy_headers.update(auth_hdrs)
_CON_POOL = mgr
def is_con_pool_initialized():
return _CON_POOL is not None
def stop_con_pool():
global _CON_POOL
if _CON_POOL is not None:
_CON_POOL.clear()
_CON_POOL = None
def set_con_pool_proxy(url, **urllib3_kwargs):
"""Setup connection pool behind a proxy
class Request(object):
"""
Helper class for python-telegram-bot which provides methods to perform POST & GET towards
telegram servers.
Args:
url (str): The URL to the proxy server. For example: `http://127.0.0.1:3128`
urllib3_kwargs (dict): Arbitrary arguments passed as-is to `urllib3.ProxyManager`
"""
global _CON_POOL_PROXY
global _CON_POOL_PROXY_KWARGS
if is_con_pool_initialized():
raise TelegramError('conpool already initialized')
_CON_POOL_PROXY = url
_CON_POOL_PROXY_KWARGS = urllib3_kwargs
def _get_con_pool_proxy():
"""Return the user configured proxy according to the following order:
* proxy configured using `set_con_pool_proxy()`.
* proxy set in `HTTPS_PROXY` env. var.
* proxy set in `https_proxy` env. var.
* None (if no proxy is configured)
Returns:
str | None
"""
if _CON_POOL_PROXY:
return _CON_POOL_PROXY
from_env = os.environ.get('HTTPS_PROXY')
if from_env:
return from_env
from_env = os.environ.get('https_proxy')
if from_env:
return from_env
return None
def _parse(json_data):
"""Try and parse the JSON returned from Telegram.
Returns:
dict: A JSON parsed as Python dict with results - on error this dict will be empty.
"""
decoded_s = json_data.decode('utf-8')
try:
data = json.loads(decoded_s)
except ValueError:
raise TelegramError('Invalid server response')
if not data.get('ok'):
description = data.get('description')
parameters = data.get('parameters')
if parameters:
migrate_to_chat_id = parameters.get('migrate_to_chat_id')
if migrate_to_chat_id:
raise ChatMigrated(migrate_to_chat_id)
if description:
return description
return data['result']
def _request_wrapper(*args, **kwargs):
"""Wraps urllib3 request for handling known exceptions.
Args:
args: unnamed arguments, passed to urllib3 request.
kwargs: keyword arguments, passed tp urllib3 request.
Returns:
str: A non-parsed JSON text.
Raises:
TelegramError
proxy_url (str): The URL to the proxy server. For example: `http://127.0.0.1:3128`.
urllib3_proxy_kwargs (dict): Arbitrary arguments passed as-is to `urllib3.ProxyManager`.
This value will be ignored if proxy_url is not set.
"""
try:
resp = _get_con_pool().request(*args, **kwargs)
except urllib3.exceptions.TimeoutError as error:
raise TimedOut()
except urllib3.exceptions.HTTPError as error:
# HTTPError must come last as its the base urllib3 exception class
# TODO: do something smart here; for now just raise NetworkError
raise NetworkError('urllib3 HTTPError {0}'.format(error))
def __init__(self, con_pool_size=1, proxy_url=None, urllib3_proxy_kwargs=None):
if urllib3_proxy_kwargs is None:
urllib3_proxy_kwargs = dict()
if 200 <= resp.status <= 299:
# 200-299 range are HTTP success statuses
return resp.data
kwargs = dict(
maxsize=con_pool_size,
cert_reqs='CERT_REQUIRED',
ca_certs=certifi.where(),
socket_options=HTTPConnection.default_socket_options + [
(socket.SOL_SOCKET, socket.SO_KEEPALIVE, 1),
])
try:
message = _parse(resp.data)
except ValueError:
raise NetworkError('Unknown HTTPError {0}'.format(resp.status))
# Set a proxy according to the following order:
# * proxy defined in proxy_url (+ urllib3_proxy_kwargs)
# * proxy set in `HTTPS_PROXY` env. var.
# * proxy set in `https_proxy` env. var.
# * None (if no proxy is configured)
if resp.status in (401, 403):
raise Unauthorized()
elif resp.status == 400:
raise BadRequest(repr(message))
elif resp.status == 502:
raise NetworkError('Bad Gateway')
else:
raise NetworkError('{0} ({1})'.format(message, resp.status))
if not proxy_url:
proxy_url = os.environ.get('HTTPS_PROXY') or os.environ.get('https_proxy')
if not proxy_url:
mgr = urllib3.PoolManager(**kwargs)
else:
kwargs.update(urllib3_proxy_kwargs)
mgr = urllib3.proxy_from_url(proxy_url, **kwargs)
if mgr.proxy.auth:
# TODO: what about other auth types?
auth_hdrs = urllib3.make_headers(proxy_basic_auth=mgr.proxy.auth)
mgr.proxy_headers.update(auth_hdrs)
def get(url):
"""Request an URL.
Args:
url:
The web location we want to retrieve.
self._con_pool = mgr
Returns:
A JSON object.
def stop(self):
self._con_pool.clear()
"""
result = _request_wrapper('GET', url)
@staticmethod
def _parse(json_data):
"""Try and parse the JSON returned from Telegram.
return _parse(result)
Returns:
dict: A JSON parsed as Python dict with results - on error this dict will be empty.
"""
decoded_s = json_data.decode('utf-8')
try:
data = json.loads(decoded_s)
except ValueError:
raise TelegramError('Invalid server response')
def post(url, data, timeout=None):
"""Request an URL.
Args:
url:
The web location we want to retrieve.
data:
A dict of (str, unicode) key/value pairs.
timeout:
float. If this value is specified, use it as the definitive timeout (in
seconds) for urlopen() operations. [Optional]
if not data.get('ok'):
description = data.get('description')
parameters = data.get('parameters')
if parameters:
migrate_to_chat_id = parameters.get('migrate_to_chat_id')
if migrate_to_chat_id:
raise ChatMigrated(migrate_to_chat_id)
if description:
return description
Notes:
If neither `timeout` nor `data['timeout']` is specified. The underlying
defaults are used.
return data['result']
Returns:
A JSON object.
def _request_wrapper(self, *args, **kwargs):
"""Wraps urllib3 request for handling known exceptions.
"""
urlopen_kwargs = {}
Args:
args: unnamed arguments, passed to urllib3 request.
kwargs: keyword arguments, passed tp urllib3 request.
if timeout is not None:
urlopen_kwargs['timeout'] = timeout
Returns:
str: A non-parsed JSON text.
if InputFile.is_inputfile(data):
data = InputFile(data)
result = _request_wrapper('POST', url, body=data.to_form(), headers=data.headers)
else:
data = json.dumps(data)
result = _request_wrapper(
'POST',
url,
body=data.encode(),
headers={'Content-Type': 'application/json'},
**urlopen_kwargs)
Raises:
TelegramError
return _parse(result)
"""
try:
resp = self._con_pool.request(*args, **kwargs)
except urllib3.exceptions.TimeoutError:
raise TimedOut()
except urllib3.exceptions.HTTPError as error:
# HTTPError must come last as its the base urllib3 exception class
# TODO: do something smart here; for now just raise NetworkError
raise NetworkError('urllib3 HTTPError {0}'.format(error))
if 200 <= resp.status <= 299:
# 200-299 range are HTTP success statuses
return resp.data
def download(url, filename):
"""Download a file by its URL.
Args:
url:
The web location we want to retrieve.
try:
message = self._parse(resp.data)
except ValueError:
raise NetworkError('Unknown HTTPError {0}'.format(resp.status))
filename:
The filename within the path to download the file.
if resp.status in (401, 403):
raise Unauthorized()
elif resp.status == 400:
raise BadRequest(repr(message))
elif resp.status == 502:
raise NetworkError('Bad Gateway')
else:
raise NetworkError('{0} ({1})'.format(message, resp.status))
"""
buf = _request_wrapper('GET', url)
with open(filename, 'wb') as fobj:
fobj.write(buf)
def get(self, url):
"""Request an URL.
Args:
url:
The web location we want to retrieve.
Returns:
A JSON object.
"""
result = self._request_wrapper('GET', url)
return self._parse(result)
def post(self, url, data, timeout=None):
"""Request an URL.
Args:
url:
The web location we want to retrieve.
data:
A dict of (str, unicode) key/value pairs.
timeout:
float. If this value is specified, use it as the definitive timeout (in
seconds) for urlopen() operations. [Optional]
Notes:
If neither `timeout` nor `data['timeout']` is specified. The underlying
defaults are used.
Returns:
A JSON object.
"""
urlopen_kwargs = {}
if timeout is not None:
urlopen_kwargs['timeout'] = timeout
if InputFile.is_inputfile(data):
data = InputFile(data)
result = self._request_wrapper('POST', url, body=data.to_form(), headers=data.headers)
else:
data = json.dumps(data)
result = self._request_wrapper(
'POST',
url,
body=data.encode(),
headers={'Content-Type': 'application/json'},
**urlopen_kwargs)
return self._parse(result)
def download(self, url, filename):
"""Download a file by its URL.
Args:
url:
The web location we want to retrieve.
filename:
The filename within the path to download the file.
"""
buf = self._request_wrapper('GET', url)
with open(filename, 'wb') as fobj:
fobj.write(buf)

View file

@ -198,7 +198,9 @@ class BotTest(BaseTest, unittest.TestCase):
def testInvalidSrvResp(self):
with self.assertRaisesRegexp(telegram.TelegramError, 'Invalid server response'):
# bypass the valid token check
bot = telegram.Bot.__new__(telegram.Bot)
newbot_cls = type(
'NoTokenValidateBot', (telegram.Bot,), dict(_validate_token=lambda x, y: None))
bot = newbot_cls('0xdeadbeef')
bot.base_url = 'https://api.telegram.org/bot{0}'.format('12')
bot.getMe()

View file

@ -36,7 +36,6 @@ except ImportError:
sys.path.append('.')
from telegram import Update, Message, TelegramError, User, Chat, Bot
from telegram.utils.request import stop_con_pool
from telegram.ext import *
from tests.base import BaseTest
from tests.test_updater import MockBot
@ -61,10 +60,10 @@ class ConversationHandlerTest(BaseTest, unittest.TestCase):
# At first we're thirsty. Then we brew coffee, we drink it
# and then we can start coding!
END, THIRSTY, BREWING, DRINKING, CODING = range(-1, 4)
_updater = None
# Test related
def setUp(self):
self.updater = None
self.current_state = dict()
self.entry_points = [CommandHandler('start', self.start)]
self.states = {self.THIRSTY: [CommandHandler('brew', self.brew),
@ -78,14 +77,22 @@ class ConversationHandlerTest(BaseTest, unittest.TestCase):
self.fallbacks = [CommandHandler('eat', self.start)]
def _setup_updater(self, *args, **kwargs):
stop_con_pool()
bot = MockBot(*args, **kwargs)
self.updater = Updater(workers=2, bot=bot)
def tearDown(self):
if self.updater is not None:
self.updater.stop()
stop_con_pool()
@property
def updater(self):
return self._updater
@updater.setter
def updater(self, val):
if self._updater:
self._updater.stop()
self._updater = val
def reset(self):
self.current_state = dict()

View file

@ -101,19 +101,19 @@ class FileTest(BaseTest, unittest.TestCase):
self.assertTrue(os.path.isfile('telegram.ogg'))
def test_file_de_json(self):
newFile = telegram.File.de_json(self.json_dict)
newFile = telegram.File.de_json(self.json_dict, None)
self.assertEqual(newFile.file_id, self.json_dict['file_id'])
self.assertEqual(newFile.file_path, self.json_dict['file_path'])
self.assertEqual(newFile.file_size, self.json_dict['file_size'])
def test_file_to_json(self):
newFile = telegram.File.de_json(self.json_dict)
newFile = telegram.File.de_json(self.json_dict, None)
self.assertTrue(self.is_json(newFile.to_json()))
def test_file_to_dict(self):
newFile = telegram.File.de_json(self.json_dict)
newFile = telegram.File.de_json(self.json_dict, None)
self.assertTrue(self.is_dict(newFile.to_dict()))
self.assertEqual(newFile['file_id'], self.json_dict['file_id'])

View file

@ -25,9 +25,10 @@ import sys
import unittest
from time import sleep
from tests.test_updater import MockBot
sys.path.append('.')
from telegram.utils.request import stop_con_pool
from telegram.ext import JobQueue, Job, Updater
from tests.base import BaseTest
@ -49,13 +50,12 @@ class JobQueueTest(BaseTest, unittest.TestCase):
"""
def setUp(self):
self.jq = JobQueue("Bot")
self.jq = JobQueue(MockBot('jobqueue_test'))
self.result = 0
def tearDown(self):
if self.jq is not None:
self.jq.stop()
stop_con_pool()
def job1(self, bot, job):
self.result += 1
@ -158,12 +158,15 @@ class JobQueueTest(BaseTest, unittest.TestCase):
def test_inUpdater(self):
u = Updater(bot="MockBot")
u.job_queue.put(Job(self.job1, 0.5))
sleep(0.75)
self.assertEqual(1, self.result)
u.stop()
sleep(2)
self.assertEqual(1, self.result)
try:
u.job_queue.put(Job(self.job1, 0.5))
sleep(0.75)
self.assertEqual(1, self.result)
u.stop()
sleep(2)
self.assertEqual(1, self.result)
finally:
u.stop()
if __name__ == '__main__':

View file

@ -29,10 +29,13 @@ import re
import unittest
from datetime import datetime
from time import sleep
from queue import Queue
from random import randrange
from future.builtins import bytes
from telegram.utils.request import Request as Requester
try:
# python2
from urllib2 import urlopen, Request, HTTPError
@ -44,12 +47,11 @@ except ImportError:
sys.path.append('.')
from telegram import Update, Message, TelegramError, User, Chat, Bot, InlineQuery, CallbackQuery
from telegram.utils.request import stop_con_pool
from telegram.ext import *
from telegram.ext.dispatcher import run_async
from telegram.error import Unauthorized, InvalidToken
from tests.base import BaseTest
from threading import Lock, Thread
from threading import Lock, Thread, current_thread, Semaphore
# Enable logging
root = logging.getLogger()
@ -68,10 +70,8 @@ class UpdaterTest(BaseTest, unittest.TestCase):
WebhookHandler
"""
updater = None
_updater = None
received_message = None
message_count = None
lock = None
def setUp(self):
self.updater = None
@ -79,15 +79,25 @@ class UpdaterTest(BaseTest, unittest.TestCase):
self.message_count = 0
self.lock = Lock()
@property
def updater(self):
return self._updater
@updater.setter
def updater(self, val):
if self._updater:
self._updater.stop()
self._updater.dispatcher._reset_singleton()
del self._updater.dispatcher
self._updater = val
def _setup_updater(self, *args, **kwargs):
stop_con_pool()
bot = MockBot(*args, **kwargs)
self.updater = Updater(workers=2, bot=bot)
def tearDown(self):
if self.updater is not None:
self.updater.stop()
stop_con_pool()
self.updater = None
def reset(self):
self.message_count = 0
@ -411,6 +421,51 @@ class UpdaterTest(BaseTest, unittest.TestCase):
self.assertEqual(self.received_message, 'Test5')
self.assertEqual(self.message_count, 2)
def test_multiple_dispatchers(self):
def get_dispatcher_name(q):
q.put(current_thread().name)
sleep(1.2)
d1 = Dispatcher(MockBot('disp1'), Queue(), workers=1)
d2 = Dispatcher(MockBot('disp2'), Queue(), workers=1)
q1 = Queue()
q2 = Queue()
try:
d1.run_async(get_dispatcher_name, q1)
d2.run_async(get_dispatcher_name, q2)
name1 = q1.get()
name2 = q2.get()
self.assertNotEqual(name1, name2)
finally:
d1.stop()
d2.stop()
# following three lines are for pypy unitests
d1._reset_singleton()
del d1
del d2
def test_multiple_dispatcers_no_decorator(self):
@run_async
def must_raise_runtime_error():
pass
d1 = Dispatcher(MockBot('disp1'), Queue(), workers=1)
d2 = Dispatcher(MockBot('disp2'), Queue(), workers=1)
self.assertRaises(RuntimeError, must_raise_runtime_error)
d1.stop()
d2.stop()
# following three lines are for pypy unitests
d1._reset_singleton()
del d1
del d2
def test_additionalArgs(self):
self._setup_updater('', messages=0)
handler = StringCommandHandler(