mirror of
https://github.com/python-telegram-bot/python-telegram-bot.git
synced 2025-01-10 20:12:52 +01:00
Reusable dispatcher (#402)
* Create a Request class which maintains its own connection pool * When creating a Bot instance a new Request instance will be created if one wasn't supplied. * Updater is responsible for creating a Request instance if a Bot instance wasn't provided. * Dispatcher: add method to run async functions without decorator * Dispatcher can now run as a singleton (allowing run_async decorator to work) as it always did and as multiple instances (where run_async decorator will raise RuntimeError)
This commit is contained in:
parent
ca81a75f29
commit
e4a132c0e4
11 changed files with 420 additions and 331 deletions
|
@ -1,15 +1,15 @@
|
||||||
- repo: git://github.com/pre-commit/mirrors-yapf
|
- repo: git://github.com/pre-commit/mirrors-yapf
|
||||||
sha: 34303f2856d4e4ba26dc302d9c28632e9b5a8626
|
sha: v0.11.0
|
||||||
hooks:
|
hooks:
|
||||||
- id: yapf
|
- id: yapf
|
||||||
files: ^(telegram|tests)/.*\.py$
|
files: ^(telegram|tests)/.*\.py$
|
||||||
- repo: git://github.com/pre-commit/pre-commit-hooks
|
- repo: git://github.com/pre-commit/pre-commit-hooks
|
||||||
sha: 3fa02652357ff0dbb42b5bc78c673b7bc105fcf3
|
sha: 18d7035de5388cc7775be57f529c154bf541aab9
|
||||||
hooks:
|
hooks:
|
||||||
- id: flake8
|
- id: flake8
|
||||||
files: ^telegram/.*\.py$
|
files: ^telegram/.*\.py$
|
||||||
- repo: git://github.com/pre-commit/mirrors-pylint
|
- repo: git://github.com/pre-commit/mirrors-pylint
|
||||||
sha: 4de6c8dfadef1a271a814561ce05b8bc1c446d22
|
sha: v1.5.5
|
||||||
hooks:
|
hooks:
|
||||||
- id: pylint
|
- id: pylint
|
||||||
files: ^telegram/.*\.py$
|
files: ^telegram/.*\.py$
|
||||||
|
|
|
@ -25,7 +25,7 @@ import logging
|
||||||
from telegram import (User, Message, Update, Chat, ChatMember, UserProfilePhotos, File,
|
from telegram import (User, Message, Update, Chat, ChatMember, UserProfilePhotos, File,
|
||||||
ReplyMarkup, TelegramObject)
|
ReplyMarkup, TelegramObject)
|
||||||
from telegram.error import InvalidToken
|
from telegram.error import InvalidToken
|
||||||
from telegram.utils import request
|
from telegram.utils.request import Request
|
||||||
|
|
||||||
logging.getLogger(__name__).addHandler(logging.NullHandler())
|
logging.getLogger(__name__).addHandler(logging.NullHandler())
|
||||||
|
|
||||||
|
@ -44,10 +44,11 @@ class Bot(TelegramObject):
|
||||||
token (str): Bot's unique authentication.
|
token (str): Bot's unique authentication.
|
||||||
base_url (Optional[str]): Telegram Bot API service URL.
|
base_url (Optional[str]): Telegram Bot API service URL.
|
||||||
base_file_url (Optional[str]): Telegram Bot API file URL.
|
base_file_url (Optional[str]): Telegram Bot API file URL.
|
||||||
|
request (Optional[Request]): Pre initialized `Request` class.
|
||||||
|
|
||||||
"""
|
"""
|
||||||
|
|
||||||
def __init__(self, token, base_url=None, base_file_url=None):
|
def __init__(self, token, base_url=None, base_file_url=None, request=None):
|
||||||
self.token = self._validate_token(token)
|
self.token = self._validate_token(token)
|
||||||
|
|
||||||
if not base_url:
|
if not base_url:
|
||||||
|
@ -61,9 +62,13 @@ class Bot(TelegramObject):
|
||||||
self.base_file_url = base_file_url + self.token
|
self.base_file_url = base_file_url + self.token
|
||||||
|
|
||||||
self.bot = None
|
self.bot = None
|
||||||
|
self._request = request or Request()
|
||||||
self.logger = logging.getLogger(__name__)
|
self.logger = logging.getLogger(__name__)
|
||||||
|
|
||||||
|
@property
|
||||||
|
def request(self):
|
||||||
|
return self._request
|
||||||
|
|
||||||
@staticmethod
|
@staticmethod
|
||||||
def _validate_token(token):
|
def _validate_token(token):
|
||||||
"""a very basic validation on token"""
|
"""a very basic validation on token"""
|
||||||
|
@ -144,7 +149,7 @@ class Bot(TelegramObject):
|
||||||
else:
|
else:
|
||||||
data['reply_markup'] = reply_markup
|
data['reply_markup'] = reply_markup
|
||||||
|
|
||||||
result = request.post(url, data, timeout=kwargs.get('timeout'))
|
result = self._request.post(url, data, timeout=kwargs.get('timeout'))
|
||||||
|
|
||||||
if result is True:
|
if result is True:
|
||||||
return result
|
return result
|
||||||
|
@ -169,7 +174,7 @@ class Bot(TelegramObject):
|
||||||
|
|
||||||
url = '{0}/getMe'.format(self.base_url)
|
url = '{0}/getMe'.format(self.base_url)
|
||||||
|
|
||||||
result = request.get(url)
|
result = self._request.get(url)
|
||||||
|
|
||||||
self.bot = User.de_json(result)
|
self.bot = User.de_json(result)
|
||||||
|
|
||||||
|
@ -813,7 +818,7 @@ class Bot(TelegramObject):
|
||||||
if switch_pm_parameter:
|
if switch_pm_parameter:
|
||||||
data['switch_pm_parameter'] = switch_pm_parameter
|
data['switch_pm_parameter'] = switch_pm_parameter
|
||||||
|
|
||||||
result = request.post(url, data, timeout=kwargs.get('timeout'))
|
result = self._request.post(url, data, timeout=kwargs.get('timeout'))
|
||||||
|
|
||||||
return result
|
return result
|
||||||
|
|
||||||
|
@ -853,7 +858,7 @@ class Bot(TelegramObject):
|
||||||
if limit:
|
if limit:
|
||||||
data['limit'] = limit
|
data['limit'] = limit
|
||||||
|
|
||||||
result = request.post(url, data, timeout=kwargs.get('timeout'))
|
result = self._request.post(url, data, timeout=kwargs.get('timeout'))
|
||||||
|
|
||||||
return UserProfilePhotos.de_json(result)
|
return UserProfilePhotos.de_json(result)
|
||||||
|
|
||||||
|
@ -884,12 +889,12 @@ class Bot(TelegramObject):
|
||||||
|
|
||||||
data = {'file_id': file_id}
|
data = {'file_id': file_id}
|
||||||
|
|
||||||
result = request.post(url, data, timeout=kwargs.get('timeout'))
|
result = self._request.post(url, data, timeout=kwargs.get('timeout'))
|
||||||
|
|
||||||
if result.get('file_path'):
|
if result.get('file_path'):
|
||||||
result['file_path'] = '%s/%s' % (self.base_file_url, result['file_path'])
|
result['file_path'] = '%s/%s' % (self.base_file_url, result['file_path'])
|
||||||
|
|
||||||
return File.de_json(result)
|
return File.de_json(result, self._request)
|
||||||
|
|
||||||
@log
|
@log
|
||||||
def kickChatMember(self, chat_id, user_id, **kwargs):
|
def kickChatMember(self, chat_id, user_id, **kwargs):
|
||||||
|
@ -921,7 +926,7 @@ class Bot(TelegramObject):
|
||||||
|
|
||||||
data = {'chat_id': chat_id, 'user_id': user_id}
|
data = {'chat_id': chat_id, 'user_id': user_id}
|
||||||
|
|
||||||
result = request.post(url, data, timeout=kwargs.get('timeout'))
|
result = self._request.post(url, data, timeout=kwargs.get('timeout'))
|
||||||
|
|
||||||
return result
|
return result
|
||||||
|
|
||||||
|
@ -955,7 +960,7 @@ class Bot(TelegramObject):
|
||||||
|
|
||||||
data = {'chat_id': chat_id, 'user_id': user_id}
|
data = {'chat_id': chat_id, 'user_id': user_id}
|
||||||
|
|
||||||
result = request.post(url, data, timeout=kwargs.get('timeout'))
|
result = self._request.post(url, data, timeout=kwargs.get('timeout'))
|
||||||
|
|
||||||
return result
|
return result
|
||||||
|
|
||||||
|
@ -999,7 +1004,7 @@ class Bot(TelegramObject):
|
||||||
if show_alert:
|
if show_alert:
|
||||||
data['show_alert'] = show_alert
|
data['show_alert'] = show_alert
|
||||||
|
|
||||||
result = request.post(url, data, timeout=kwargs.get('timeout'))
|
result = self._request.post(url, data, timeout=kwargs.get('timeout'))
|
||||||
|
|
||||||
return result
|
return result
|
||||||
|
|
||||||
|
@ -1213,7 +1218,7 @@ class Bot(TelegramObject):
|
||||||
|
|
||||||
urlopen_timeout = timeout + network_delay
|
urlopen_timeout = timeout + network_delay
|
||||||
|
|
||||||
result = request.post(url, data, timeout=urlopen_timeout)
|
result = self._request.post(url, data, timeout=urlopen_timeout)
|
||||||
|
|
||||||
if result:
|
if result:
|
||||||
self.logger.debug('Getting updates: %s', [u['update_id'] for u in result])
|
self.logger.debug('Getting updates: %s', [u['update_id'] for u in result])
|
||||||
|
@ -1256,7 +1261,7 @@ class Bot(TelegramObject):
|
||||||
if certificate:
|
if certificate:
|
||||||
data['certificate'] = certificate
|
data['certificate'] = certificate
|
||||||
|
|
||||||
result = request.post(url, data, timeout=kwargs.get('timeout'))
|
result = self._request.post(url, data, timeout=kwargs.get('timeout'))
|
||||||
|
|
||||||
return result
|
return result
|
||||||
|
|
||||||
|
@ -1286,7 +1291,7 @@ class Bot(TelegramObject):
|
||||||
|
|
||||||
data = {'chat_id': chat_id}
|
data = {'chat_id': chat_id}
|
||||||
|
|
||||||
result = request.post(url, data, timeout=kwargs.get('timeout'))
|
result = self._request.post(url, data, timeout=kwargs.get('timeout'))
|
||||||
|
|
||||||
return result
|
return result
|
||||||
|
|
||||||
|
@ -1318,7 +1323,7 @@ class Bot(TelegramObject):
|
||||||
|
|
||||||
data = {'chat_id': chat_id}
|
data = {'chat_id': chat_id}
|
||||||
|
|
||||||
result = request.post(url, data, timeout=kwargs.get('timeout'))
|
result = self._request.post(url, data, timeout=kwargs.get('timeout'))
|
||||||
|
|
||||||
return Chat.de_json(result)
|
return Chat.de_json(result)
|
||||||
|
|
||||||
|
@ -1353,7 +1358,7 @@ class Bot(TelegramObject):
|
||||||
|
|
||||||
data = {'chat_id': chat_id}
|
data = {'chat_id': chat_id}
|
||||||
|
|
||||||
result = request.post(url, data, timeout=kwargs.get('timeout'))
|
result = self._request.post(url, data, timeout=kwargs.get('timeout'))
|
||||||
|
|
||||||
return [ChatMember.de_json(x) for x in result]
|
return [ChatMember.de_json(x) for x in result]
|
||||||
|
|
||||||
|
@ -1383,7 +1388,7 @@ class Bot(TelegramObject):
|
||||||
|
|
||||||
data = {'chat_id': chat_id}
|
data = {'chat_id': chat_id}
|
||||||
|
|
||||||
result = request.post(url, data, timeout=kwargs.get('timeout'))
|
result = self._request.post(url, data, timeout=kwargs.get('timeout'))
|
||||||
|
|
||||||
return result
|
return result
|
||||||
|
|
||||||
|
@ -1416,7 +1421,7 @@ class Bot(TelegramObject):
|
||||||
|
|
||||||
data = {'chat_id': chat_id, 'user_id': user_id}
|
data = {'chat_id': chat_id, 'user_id': user_id}
|
||||||
|
|
||||||
result = request.post(url, data, timeout=kwargs.get('timeout'))
|
result = self._request.post(url, data, timeout=kwargs.get('timeout'))
|
||||||
|
|
||||||
return ChatMember.de_json(result)
|
return ChatMember.de_json(result)
|
||||||
|
|
||||||
|
|
|
@ -19,70 +19,43 @@
|
||||||
"""This module contains the Dispatcher class."""
|
"""This module contains the Dispatcher class."""
|
||||||
|
|
||||||
import logging
|
import logging
|
||||||
|
import weakref
|
||||||
from functools import wraps
|
from functools import wraps
|
||||||
from threading import Thread, Lock, Event, current_thread
|
from threading import Thread, Lock, Event, current_thread, BoundedSemaphore
|
||||||
from time import sleep
|
from time import sleep
|
||||||
|
from uuid import uuid4
|
||||||
|
|
||||||
from queue import Queue, Empty
|
from queue import Queue, Empty
|
||||||
|
|
||||||
from future.builtins import range
|
from future.builtins import range
|
||||||
|
|
||||||
from telegram import TelegramError
|
from telegram import TelegramError
|
||||||
from telegram.utils import request
|
|
||||||
from telegram.ext.handler import Handler
|
from telegram.ext.handler import Handler
|
||||||
from telegram.utils.deprecate import deprecate
|
from telegram.utils.deprecate import deprecate
|
||||||
from telegram.utils.promise import Promise
|
from telegram.utils.promise import Promise
|
||||||
|
|
||||||
logging.getLogger(__name__).addHandler(logging.NullHandler())
|
logging.getLogger(__name__).addHandler(logging.NullHandler())
|
||||||
|
|
||||||
ASYNC_QUEUE = Queue()
|
|
||||||
ASYNC_THREADS = set()
|
|
||||||
""":type: set[Thread]"""
|
""":type: set[Thread]"""
|
||||||
ASYNC_LOCK = Lock() # guards ASYNC_THREADS
|
|
||||||
DEFAULT_GROUP = 0
|
DEFAULT_GROUP = 0
|
||||||
|
|
||||||
|
|
||||||
def _pooled():
|
|
||||||
"""
|
|
||||||
A wrapper to run a thread in a thread pool
|
|
||||||
"""
|
|
||||||
while 1:
|
|
||||||
promise = ASYNC_QUEUE.get()
|
|
||||||
|
|
||||||
# If unpacking fails, the thread pool is being closed from Updater._join_async_threads
|
|
||||||
if not isinstance(promise, Promise):
|
|
||||||
logging.getLogger(__name__).debug("Closing run_async thread %s/%d" %
|
|
||||||
(current_thread().getName(), len(ASYNC_THREADS)))
|
|
||||||
break
|
|
||||||
|
|
||||||
try:
|
|
||||||
promise.run()
|
|
||||||
|
|
||||||
except:
|
|
||||||
logging.getLogger(__name__).exception("run_async function raised exception")
|
|
||||||
|
|
||||||
|
|
||||||
def run_async(func):
|
def run_async(func):
|
||||||
"""
|
"""Function decorator that will run the function in a new thread.
|
||||||
Function decorator that will run the function in a new thread.
|
|
||||||
|
Using this decorator is only possible when only a single Dispatcher exist in the system.
|
||||||
|
|
||||||
Args:
|
Args:
|
||||||
func (function): The function to run in the thread.
|
func (function): The function to run in the thread.
|
||||||
|
async_queue (Queue): The queue of the functions to be executed asynchronously.
|
||||||
|
|
||||||
Returns:
|
Returns:
|
||||||
function:
|
function:
|
||||||
"""
|
|
||||||
|
|
||||||
# TODO: handle exception in async threads
|
"""
|
||||||
# set a threading.Event to notify caller thread
|
|
||||||
|
|
||||||
@wraps(func)
|
@wraps(func)
|
||||||
def async_func(*args, **kwargs):
|
def async_func(*args, **kwargs):
|
||||||
"""
|
return Dispatcher.get_instance().run_async(func, *args, **kwargs)
|
||||||
A wrapper to run a function in a thread
|
|
||||||
"""
|
|
||||||
promise = Promise(func, args, kwargs)
|
|
||||||
ASYNC_QUEUE.put(promise)
|
|
||||||
return promise
|
|
||||||
|
|
||||||
return async_func
|
return async_func
|
||||||
|
|
||||||
|
@ -100,7 +73,12 @@ class Dispatcher(object):
|
||||||
callbacks
|
callbacks
|
||||||
workers (Optional[int]): Number of maximum concurrent worker threads for the ``@run_async``
|
workers (Optional[int]): Number of maximum concurrent worker threads for the ``@run_async``
|
||||||
decorator
|
decorator
|
||||||
|
|
||||||
"""
|
"""
|
||||||
|
__singleton_lock = Lock()
|
||||||
|
__singleton_semaphore = BoundedSemaphore()
|
||||||
|
__singleton = None
|
||||||
|
logger = logging.getLogger(__name__)
|
||||||
|
|
||||||
def __init__(self, bot, update_queue, workers=4, exception_event=None, job_queue=None):
|
def __init__(self, bot, update_queue, workers=4, exception_event=None, job_queue=None):
|
||||||
self.bot = bot
|
self.bot = bot
|
||||||
|
@ -113,28 +91,92 @@ class Dispatcher(object):
|
||||||
""":type: list[int]"""
|
""":type: list[int]"""
|
||||||
self.error_handlers = []
|
self.error_handlers = []
|
||||||
|
|
||||||
self.logger = logging.getLogger(__name__)
|
|
||||||
self.running = False
|
self.running = False
|
||||||
self.__stop_event = Event()
|
self.__stop_event = Event()
|
||||||
self.__exception_event = exception_event or Event()
|
self.__exception_event = exception_event or Event()
|
||||||
|
self.__async_queue = Queue()
|
||||||
|
self.__async_threads = set()
|
||||||
|
|
||||||
with ASYNC_LOCK:
|
# For backward compatibility, we allow a "singleton" mode for the dispatcher. When there's
|
||||||
if not ASYNC_THREADS:
|
# only one instance of Dispatcher, it will be possible to use the `run_async` decorator.
|
||||||
if request.is_con_pool_initialized():
|
with self.__singleton_lock:
|
||||||
raise RuntimeError('Connection Pool already initialized')
|
if self.__singleton_semaphore.acquire(blocking=0):
|
||||||
|
self._set_singleton(self)
|
||||||
# we need a connection pool the size of:
|
|
||||||
# * for each of the workers
|
|
||||||
# * 1 for Dispatcher
|
|
||||||
# * 1 for polling Updater (even if updater is webhook, we can spare a connection)
|
|
||||||
# * 1 for JobQueue
|
|
||||||
request.CON_POOL_SIZE = workers + 3
|
|
||||||
for i in range(workers):
|
|
||||||
thread = Thread(target=_pooled, name=str(i))
|
|
||||||
ASYNC_THREADS.add(thread)
|
|
||||||
thread.start()
|
|
||||||
else:
|
else:
|
||||||
self.logger.debug('Thread pool already initialized, skipping.')
|
self._set_singleton(None)
|
||||||
|
|
||||||
|
self._init_async_threads(uuid4(), workers)
|
||||||
|
|
||||||
|
@classmethod
|
||||||
|
def _reset_singleton(cls):
|
||||||
|
# NOTE: This method was added mainly for test_updater benefit and specifically pypy. Never
|
||||||
|
# call it in production code.
|
||||||
|
cls.__singleton_semaphore.release()
|
||||||
|
|
||||||
|
def _init_async_threads(self, base_name, workers):
|
||||||
|
base_name = '{}_'.format(base_name) if base_name else ''
|
||||||
|
|
||||||
|
for i in range(workers):
|
||||||
|
thread = Thread(target=self._pooled, name='{}{}'.format(base_name, i))
|
||||||
|
self.__async_threads.add(thread)
|
||||||
|
thread.start()
|
||||||
|
|
||||||
|
@classmethod
|
||||||
|
def _set_singleton(cls, val):
|
||||||
|
cls.logger.debug('Setting singleton dispatcher as %s', val)
|
||||||
|
cls.__singleton = weakref.ref(val) if val else None
|
||||||
|
|
||||||
|
@classmethod
|
||||||
|
def get_instance(cls):
|
||||||
|
"""Get the singleton instance of this class.
|
||||||
|
|
||||||
|
Returns:
|
||||||
|
Dispatcher
|
||||||
|
|
||||||
|
"""
|
||||||
|
if cls.__singleton is not None:
|
||||||
|
return cls.__singleton()
|
||||||
|
else:
|
||||||
|
raise RuntimeError('{} not initialized or multiple instances exist'.format(
|
||||||
|
cls.__name__))
|
||||||
|
|
||||||
|
def _pooled(self):
|
||||||
|
"""
|
||||||
|
A wrapper to run a thread in a thread pool
|
||||||
|
"""
|
||||||
|
thr_name = current_thread().getName()
|
||||||
|
while 1:
|
||||||
|
promise = self.__async_queue.get()
|
||||||
|
|
||||||
|
# If unpacking fails, the thread pool is being closed from Updater._join_async_threads
|
||||||
|
if not isinstance(promise, Promise):
|
||||||
|
self.logger.debug("Closing run_async thread %s/%d", thr_name,
|
||||||
|
len(self.__async_threads))
|
||||||
|
break
|
||||||
|
|
||||||
|
try:
|
||||||
|
promise.run()
|
||||||
|
|
||||||
|
except:
|
||||||
|
self.logger.exception("run_async function raised exception")
|
||||||
|
|
||||||
|
def run_async(self, func, *args, **kwargs):
|
||||||
|
"""Queue a function (with given args/kwargs) to be run asynchronously.
|
||||||
|
|
||||||
|
Args:
|
||||||
|
func (function): The function to run in the thread.
|
||||||
|
args (Optional[tuple]): Arguments to `func`.
|
||||||
|
kwargs (Optional[dict]): Keyword arguments to `func`.
|
||||||
|
|
||||||
|
Returns:
|
||||||
|
Promise
|
||||||
|
|
||||||
|
"""
|
||||||
|
# TODO: handle exception in async threads
|
||||||
|
# set a threading.Event to notify caller thread
|
||||||
|
promise = Promise(func, args, kwargs)
|
||||||
|
self.__async_queue.put(promise)
|
||||||
|
return promise
|
||||||
|
|
||||||
def start(self):
|
def start(self):
|
||||||
"""
|
"""
|
||||||
|
@ -183,6 +225,25 @@ class Dispatcher(object):
|
||||||
sleep(0.1)
|
sleep(0.1)
|
||||||
self.__stop_event.clear()
|
self.__stop_event.clear()
|
||||||
|
|
||||||
|
# async threads must be join()ed only after the dispatcher thread was joined,
|
||||||
|
# otherwise we can still have new async threads dispatched
|
||||||
|
threads = list(self.__async_threads)
|
||||||
|
total = len(threads)
|
||||||
|
|
||||||
|
# Stop all threads in the thread pool by put()ting one non-tuple per thread
|
||||||
|
for i in range(total):
|
||||||
|
self.__async_queue.put(None)
|
||||||
|
|
||||||
|
for i, thr in enumerate(threads):
|
||||||
|
self.logger.debug('Waiting for async thread {0}/{1} to end'.format(i + 1, total))
|
||||||
|
thr.join()
|
||||||
|
self.__async_threads.remove(thr)
|
||||||
|
self.logger.debug('async thread {0}/{1} has ended'.format(i + 1, total))
|
||||||
|
|
||||||
|
@property
|
||||||
|
def has_running_threads(self):
|
||||||
|
return self.running or bool(self.__async_threads)
|
||||||
|
|
||||||
def process_update(self, update):
|
def process_update(self, update):
|
||||||
"""
|
"""
|
||||||
Processes a single update.
|
Processes a single update.
|
||||||
|
|
|
@ -29,8 +29,9 @@ from signal import signal, SIGINT, SIGTERM, SIGABRT
|
||||||
from queue import Queue
|
from queue import Queue
|
||||||
|
|
||||||
from telegram import Bot, TelegramError
|
from telegram import Bot, TelegramError
|
||||||
from telegram.ext import dispatcher, Dispatcher, JobQueue
|
from telegram.ext import Dispatcher, JobQueue
|
||||||
from telegram.error import Unauthorized, InvalidToken
|
from telegram.error import Unauthorized, InvalidToken
|
||||||
|
from telegram.utils.request import Request
|
||||||
from telegram.utils.webhookhandler import (WebhookServer, WebhookHandler)
|
from telegram.utils.webhookhandler import (WebhookServer, WebhookHandler)
|
||||||
|
|
||||||
logging.getLogger(__name__).addHandler(logging.NullHandler())
|
logging.getLogger(__name__).addHandler(logging.NullHandler())
|
||||||
|
@ -57,13 +58,17 @@ class Updater(object):
|
||||||
base_url (Optional[str]):
|
base_url (Optional[str]):
|
||||||
workers (Optional[int]): Amount of threads in the thread pool for
|
workers (Optional[int]): Amount of threads in the thread pool for
|
||||||
functions decorated with @run_async
|
functions decorated with @run_async
|
||||||
bot (Optional[Bot]):
|
bot (Optional[Bot]): A pre-initialized bot instance. If a pre-initizlied bot is used, it is
|
||||||
|
the user's responsibility to create it using a `Request` instance with a large enough
|
||||||
|
connection pool.
|
||||||
job_queue_tick_interval(Optional[float]): The interval the queue should
|
job_queue_tick_interval(Optional[float]): The interval the queue should
|
||||||
be checked for new tasks. Defaults to 1.0
|
be checked for new tasks. Defaults to 1.0
|
||||||
|
|
||||||
Raises:
|
Raises:
|
||||||
ValueError: If both `token` and `bot` are passed or none of them.
|
ValueError: If both `token` and `bot` are passed or none of them.
|
||||||
|
|
||||||
"""
|
"""
|
||||||
|
_request = None
|
||||||
|
|
||||||
def __init__(self, token=None, base_url=None, workers=4, bot=None):
|
def __init__(self, token=None, base_url=None, workers=4, bot=None):
|
||||||
if (token is None) and (bot is None):
|
if (token is None) and (bot is None):
|
||||||
|
@ -74,7 +79,14 @@ class Updater(object):
|
||||||
if bot is not None:
|
if bot is not None:
|
||||||
self.bot = bot
|
self.bot = bot
|
||||||
else:
|
else:
|
||||||
self.bot = Bot(token, base_url)
|
# we need a connection pool the size of:
|
||||||
|
# * for each of the workers
|
||||||
|
# * 1 for Dispatcher
|
||||||
|
# * 1 for polling Updater (even if webhook is used, we can spare a connection)
|
||||||
|
# * 1 for JobQueue
|
||||||
|
# * 1 for main thread
|
||||||
|
self._request = Request(con_pool_size=workers + 4)
|
||||||
|
self.bot = Bot(token, base_url, request=self._request)
|
||||||
self.update_queue = Queue()
|
self.update_queue = Queue()
|
||||||
self.job_queue = JobQueue(self.bot)
|
self.job_queue = JobQueue(self.bot)
|
||||||
self.__exception_event = Event()
|
self.__exception_event = Event()
|
||||||
|
@ -344,7 +356,7 @@ class Updater(object):
|
||||||
|
|
||||||
self.job_queue.stop()
|
self.job_queue.stop()
|
||||||
with self.__lock:
|
with self.__lock:
|
||||||
if self.running or dispatcher.ASYNC_THREADS:
|
if self.running or self.dispatcher.has_running_threads:
|
||||||
self.logger.debug('Stopping Updater and Dispatcher...')
|
self.logger.debug('Stopping Updater and Dispatcher...')
|
||||||
|
|
||||||
self.running = False
|
self.running = False
|
||||||
|
@ -352,9 +364,10 @@ class Updater(object):
|
||||||
self._stop_httpd()
|
self._stop_httpd()
|
||||||
self._stop_dispatcher()
|
self._stop_dispatcher()
|
||||||
self._join_threads()
|
self._join_threads()
|
||||||
# async threads must be join()ed only after the dispatcher thread was joined,
|
|
||||||
# otherwise we can still have new async threads dispatched
|
# Stop the Request instance only if it was created by the Updater
|
||||||
self._join_async_threads()
|
if self._request:
|
||||||
|
self._request.stop()
|
||||||
|
|
||||||
def _stop_httpd(self):
|
def _stop_httpd(self):
|
||||||
if self.httpd:
|
if self.httpd:
|
||||||
|
@ -368,21 +381,6 @@ class Updater(object):
|
||||||
self.logger.debug('Requesting Dispatcher to stop...')
|
self.logger.debug('Requesting Dispatcher to stop...')
|
||||||
self.dispatcher.stop()
|
self.dispatcher.stop()
|
||||||
|
|
||||||
def _join_async_threads(self):
|
|
||||||
with dispatcher.ASYNC_LOCK:
|
|
||||||
threads = list(dispatcher.ASYNC_THREADS)
|
|
||||||
total = len(threads)
|
|
||||||
|
|
||||||
# Stop all threads in the thread pool by put()ting one non-tuple per thread
|
|
||||||
for i in range(total):
|
|
||||||
dispatcher.ASYNC_QUEUE.put(None)
|
|
||||||
|
|
||||||
for i, thr in enumerate(threads):
|
|
||||||
self.logger.debug('Waiting for async thread {0}/{1} to end'.format(i + 1, total))
|
|
||||||
thr.join()
|
|
||||||
dispatcher.ASYNC_THREADS.remove(thr)
|
|
||||||
self.logger.debug('async thread {0}/{1} has ended'.format(i + 1, total))
|
|
||||||
|
|
||||||
def _join_threads(self):
|
def _join_threads(self):
|
||||||
for thr in self.__threads:
|
for thr in self.__threads:
|
||||||
self.logger.debug('Waiting for {0} thread to end'.format(thr.name))
|
self.logger.debug('Waiting for {0} thread to end'.format(thr.name))
|
||||||
|
|
|
@ -21,7 +21,6 @@
|
||||||
from os.path import basename
|
from os.path import basename
|
||||||
|
|
||||||
from telegram import TelegramObject
|
from telegram import TelegramObject
|
||||||
from telegram.utils.request import download as _download
|
|
||||||
|
|
||||||
|
|
||||||
class File(TelegramObject):
|
class File(TelegramObject):
|
||||||
|
@ -34,38 +33,44 @@ class File(TelegramObject):
|
||||||
|
|
||||||
Args:
|
Args:
|
||||||
file_id (str):
|
file_id (str):
|
||||||
|
request (telegram.utils.request.Request):
|
||||||
**kwargs: Arbitrary keyword arguments.
|
**kwargs: Arbitrary keyword arguments.
|
||||||
|
|
||||||
Keyword Args:
|
Keyword Args:
|
||||||
file_size (Optional[int]):
|
file_size (Optional[int]):
|
||||||
file_path (Optional[str]):
|
file_path (Optional[str]):
|
||||||
|
|
||||||
"""
|
"""
|
||||||
|
|
||||||
def __init__(self, file_id, **kwargs):
|
def __init__(self, file_id, request, **kwargs):
|
||||||
# Required
|
# Required
|
||||||
self.file_id = str(file_id)
|
self.file_id = str(file_id)
|
||||||
|
self._request = request
|
||||||
# Optionals
|
# Optionals
|
||||||
self.file_size = int(kwargs.get('file_size', 0))
|
self.file_size = int(kwargs.get('file_size', 0))
|
||||||
self.file_path = str(kwargs.get('file_path', ''))
|
self.file_path = str(kwargs.get('file_path', ''))
|
||||||
|
|
||||||
@staticmethod
|
@staticmethod
|
||||||
def de_json(data):
|
def de_json(data, request):
|
||||||
"""
|
"""
|
||||||
Args:
|
Args:
|
||||||
data (str):
|
data (dict):
|
||||||
|
request (telegram.utils.request.Request):
|
||||||
|
|
||||||
Returns:
|
Returns:
|
||||||
telegram.File:
|
telegram.File:
|
||||||
|
|
||||||
"""
|
"""
|
||||||
if not data:
|
if not data:
|
||||||
return None
|
return None
|
||||||
|
|
||||||
return File(**data)
|
return File(request=request, **data)
|
||||||
|
|
||||||
def download(self, custom_path=None):
|
def download(self, custom_path=None):
|
||||||
"""
|
"""
|
||||||
Args:
|
Args:
|
||||||
custom_path (str):
|
custom_path (str):
|
||||||
|
|
||||||
"""
|
"""
|
||||||
url = self.file_path
|
url = self.file_path
|
||||||
|
|
||||||
|
@ -74,4 +79,4 @@ class File(TelegramObject):
|
||||||
else:
|
else:
|
||||||
filename = basename(url)
|
filename = basename(url)
|
||||||
|
|
||||||
_download(url, filename)
|
self._request.download(url, filename)
|
||||||
|
|
|
@ -33,232 +33,185 @@ from urllib3.connection import HTTPConnection
|
||||||
from telegram import (InputFile, TelegramError)
|
from telegram import (InputFile, TelegramError)
|
||||||
from telegram.error import Unauthorized, NetworkError, TimedOut, BadRequest, ChatMigrated
|
from telegram.error import Unauthorized, NetworkError, TimedOut, BadRequest, ChatMigrated
|
||||||
|
|
||||||
_CON_POOL = None
|
|
||||||
""":type: urllib3.PoolManager"""
|
|
||||||
_CON_POOL_PROXY = None
|
|
||||||
_CON_POOL_PROXY_KWARGS = {}
|
|
||||||
CON_POOL_SIZE = 1
|
|
||||||
|
|
||||||
logging.getLogger('urllib3').setLevel(logging.WARNING)
|
logging.getLogger('urllib3').setLevel(logging.WARNING)
|
||||||
|
|
||||||
|
|
||||||
def _get_con_pool():
|
class Request(object):
|
||||||
if _CON_POOL is not None:
|
"""
|
||||||
return _CON_POOL
|
Helper class for python-telegram-bot which provides methods to perform POST & GET towards
|
||||||
|
telegram servers.
|
||||||
_init_con_pool()
|
|
||||||
return _CON_POOL
|
|
||||||
|
|
||||||
|
|
||||||
def _init_con_pool():
|
|
||||||
global _CON_POOL
|
|
||||||
kwargs = dict(
|
|
||||||
maxsize=CON_POOL_SIZE,
|
|
||||||
cert_reqs='CERT_REQUIRED',
|
|
||||||
ca_certs=certifi.where(),
|
|
||||||
socket_options=HTTPConnection.default_socket_options + [
|
|
||||||
(socket.SOL_SOCKET, socket.SO_KEEPALIVE, 1),
|
|
||||||
])
|
|
||||||
proxy_url = _get_con_pool_proxy()
|
|
||||||
if not proxy_url:
|
|
||||||
mgr = urllib3.PoolManager(**kwargs)
|
|
||||||
else:
|
|
||||||
if _CON_POOL_PROXY_KWARGS:
|
|
||||||
kwargs.update(_CON_POOL_PROXY_KWARGS)
|
|
||||||
mgr = urllib3.proxy_from_url(proxy_url, **kwargs)
|
|
||||||
if mgr.proxy.auth:
|
|
||||||
# TODO: what about other auth types?
|
|
||||||
auth_hdrs = urllib3.make_headers(proxy_basic_auth=mgr.proxy.auth)
|
|
||||||
mgr.proxy_headers.update(auth_hdrs)
|
|
||||||
|
|
||||||
_CON_POOL = mgr
|
|
||||||
|
|
||||||
|
|
||||||
def is_con_pool_initialized():
|
|
||||||
return _CON_POOL is not None
|
|
||||||
|
|
||||||
|
|
||||||
def stop_con_pool():
|
|
||||||
global _CON_POOL
|
|
||||||
if _CON_POOL is not None:
|
|
||||||
_CON_POOL.clear()
|
|
||||||
_CON_POOL = None
|
|
||||||
|
|
||||||
|
|
||||||
def set_con_pool_proxy(url, **urllib3_kwargs):
|
|
||||||
"""Setup connection pool behind a proxy
|
|
||||||
|
|
||||||
Args:
|
Args:
|
||||||
url (str): The URL to the proxy server. For example: `http://127.0.0.1:3128`
|
proxy_url (str): The URL to the proxy server. For example: `http://127.0.0.1:3128`.
|
||||||
urllib3_kwargs (dict): Arbitrary arguments passed as-is to `urllib3.ProxyManager`
|
urllib3_proxy_kwargs (dict): Arbitrary arguments passed as-is to `urllib3.ProxyManager`.
|
||||||
|
This value will be ignored if proxy_url is not set.
|
||||||
"""
|
|
||||||
global _CON_POOL_PROXY
|
|
||||||
global _CON_POOL_PROXY_KWARGS
|
|
||||||
|
|
||||||
if is_con_pool_initialized():
|
|
||||||
raise TelegramError('conpool already initialized')
|
|
||||||
|
|
||||||
_CON_POOL_PROXY = url
|
|
||||||
_CON_POOL_PROXY_KWARGS = urllib3_kwargs
|
|
||||||
|
|
||||||
|
|
||||||
def _get_con_pool_proxy():
|
|
||||||
"""Return the user configured proxy according to the following order:
|
|
||||||
|
|
||||||
* proxy configured using `set_con_pool_proxy()`.
|
|
||||||
* proxy set in `HTTPS_PROXY` env. var.
|
|
||||||
* proxy set in `https_proxy` env. var.
|
|
||||||
* None (if no proxy is configured)
|
|
||||||
|
|
||||||
Returns:
|
|
||||||
str | None
|
|
||||||
|
|
||||||
"""
|
|
||||||
if _CON_POOL_PROXY:
|
|
||||||
return _CON_POOL_PROXY
|
|
||||||
from_env = os.environ.get('HTTPS_PROXY')
|
|
||||||
if from_env:
|
|
||||||
return from_env
|
|
||||||
from_env = os.environ.get('https_proxy')
|
|
||||||
if from_env:
|
|
||||||
return from_env
|
|
||||||
return None
|
|
||||||
|
|
||||||
|
|
||||||
def _parse(json_data):
|
|
||||||
"""Try and parse the JSON returned from Telegram.
|
|
||||||
|
|
||||||
Returns:
|
|
||||||
dict: A JSON parsed as Python dict with results - on error this dict will be empty.
|
|
||||||
|
|
||||||
"""
|
|
||||||
decoded_s = json_data.decode('utf-8')
|
|
||||||
try:
|
|
||||||
data = json.loads(decoded_s)
|
|
||||||
except ValueError:
|
|
||||||
raise TelegramError('Invalid server response')
|
|
||||||
|
|
||||||
if not data.get('ok'):
|
|
||||||
description = data.get('description')
|
|
||||||
parameters = data.get('parameters')
|
|
||||||
if parameters:
|
|
||||||
migrate_to_chat_id = parameters.get('migrate_to_chat_id')
|
|
||||||
if migrate_to_chat_id:
|
|
||||||
raise ChatMigrated(migrate_to_chat_id)
|
|
||||||
if description:
|
|
||||||
return description
|
|
||||||
|
|
||||||
return data['result']
|
|
||||||
|
|
||||||
|
|
||||||
def _request_wrapper(*args, **kwargs):
|
|
||||||
"""Wraps urllib3 request for handling known exceptions.
|
|
||||||
|
|
||||||
Args:
|
|
||||||
args: unnamed arguments, passed to urllib3 request.
|
|
||||||
kwargs: keyword arguments, passed tp urllib3 request.
|
|
||||||
|
|
||||||
Returns:
|
|
||||||
str: A non-parsed JSON text.
|
|
||||||
|
|
||||||
Raises:
|
|
||||||
TelegramError
|
|
||||||
|
|
||||||
"""
|
"""
|
||||||
|
|
||||||
try:
|
def __init__(self, con_pool_size=1, proxy_url=None, urllib3_proxy_kwargs=None):
|
||||||
resp = _get_con_pool().request(*args, **kwargs)
|
if urllib3_proxy_kwargs is None:
|
||||||
except urllib3.exceptions.TimeoutError as error:
|
urllib3_proxy_kwargs = dict()
|
||||||
raise TimedOut()
|
|
||||||
except urllib3.exceptions.HTTPError as error:
|
|
||||||
# HTTPError must come last as its the base urllib3 exception class
|
|
||||||
# TODO: do something smart here; for now just raise NetworkError
|
|
||||||
raise NetworkError('urllib3 HTTPError {0}'.format(error))
|
|
||||||
|
|
||||||
if 200 <= resp.status <= 299:
|
kwargs = dict(
|
||||||
# 200-299 range are HTTP success statuses
|
maxsize=con_pool_size,
|
||||||
return resp.data
|
cert_reqs='CERT_REQUIRED',
|
||||||
|
ca_certs=certifi.where(),
|
||||||
|
socket_options=HTTPConnection.default_socket_options + [
|
||||||
|
(socket.SOL_SOCKET, socket.SO_KEEPALIVE, 1),
|
||||||
|
])
|
||||||
|
|
||||||
try:
|
# Set a proxy according to the following order:
|
||||||
message = _parse(resp.data)
|
# * proxy defined in proxy_url (+ urllib3_proxy_kwargs)
|
||||||
except ValueError:
|
# * proxy set in `HTTPS_PROXY` env. var.
|
||||||
raise NetworkError('Unknown HTTPError {0}'.format(resp.status))
|
# * proxy set in `https_proxy` env. var.
|
||||||
|
# * None (if no proxy is configured)
|
||||||
|
|
||||||
if resp.status in (401, 403):
|
if not proxy_url:
|
||||||
raise Unauthorized()
|
proxy_url = os.environ.get('HTTPS_PROXY') or os.environ.get('https_proxy')
|
||||||
elif resp.status == 400:
|
|
||||||
raise BadRequest(repr(message))
|
|
||||||
elif resp.status == 502:
|
|
||||||
raise NetworkError('Bad Gateway')
|
|
||||||
else:
|
|
||||||
raise NetworkError('{0} ({1})'.format(message, resp.status))
|
|
||||||
|
|
||||||
|
if not proxy_url:
|
||||||
|
mgr = urllib3.PoolManager(**kwargs)
|
||||||
|
else:
|
||||||
|
kwargs.update(urllib3_proxy_kwargs)
|
||||||
|
mgr = urllib3.proxy_from_url(proxy_url, **kwargs)
|
||||||
|
if mgr.proxy.auth:
|
||||||
|
# TODO: what about other auth types?
|
||||||
|
auth_hdrs = urllib3.make_headers(proxy_basic_auth=mgr.proxy.auth)
|
||||||
|
mgr.proxy_headers.update(auth_hdrs)
|
||||||
|
|
||||||
def get(url):
|
self._con_pool = mgr
|
||||||
"""Request an URL.
|
|
||||||
Args:
|
|
||||||
url:
|
|
||||||
The web location we want to retrieve.
|
|
||||||
|
|
||||||
Returns:
|
def stop(self):
|
||||||
A JSON object.
|
self._con_pool.clear()
|
||||||
|
|
||||||
"""
|
@staticmethod
|
||||||
result = _request_wrapper('GET', url)
|
def _parse(json_data):
|
||||||
|
"""Try and parse the JSON returned from Telegram.
|
||||||
|
|
||||||
return _parse(result)
|
Returns:
|
||||||
|
dict: A JSON parsed as Python dict with results - on error this dict will be empty.
|
||||||
|
|
||||||
|
"""
|
||||||
|
decoded_s = json_data.decode('utf-8')
|
||||||
|
try:
|
||||||
|
data = json.loads(decoded_s)
|
||||||
|
except ValueError:
|
||||||
|
raise TelegramError('Invalid server response')
|
||||||
|
|
||||||
def post(url, data, timeout=None):
|
if not data.get('ok'):
|
||||||
"""Request an URL.
|
description = data.get('description')
|
||||||
Args:
|
parameters = data.get('parameters')
|
||||||
url:
|
if parameters:
|
||||||
The web location we want to retrieve.
|
migrate_to_chat_id = parameters.get('migrate_to_chat_id')
|
||||||
data:
|
if migrate_to_chat_id:
|
||||||
A dict of (str, unicode) key/value pairs.
|
raise ChatMigrated(migrate_to_chat_id)
|
||||||
timeout:
|
if description:
|
||||||
float. If this value is specified, use it as the definitive timeout (in
|
return description
|
||||||
seconds) for urlopen() operations. [Optional]
|
|
||||||
|
|
||||||
Notes:
|
return data['result']
|
||||||
If neither `timeout` nor `data['timeout']` is specified. The underlying
|
|
||||||
defaults are used.
|
|
||||||
|
|
||||||
Returns:
|
def _request_wrapper(self, *args, **kwargs):
|
||||||
A JSON object.
|
"""Wraps urllib3 request for handling known exceptions.
|
||||||
|
|
||||||
"""
|
Args:
|
||||||
urlopen_kwargs = {}
|
args: unnamed arguments, passed to urllib3 request.
|
||||||
|
kwargs: keyword arguments, passed tp urllib3 request.
|
||||||
|
|
||||||
if timeout is not None:
|
Returns:
|
||||||
urlopen_kwargs['timeout'] = timeout
|
str: A non-parsed JSON text.
|
||||||
|
|
||||||
if InputFile.is_inputfile(data):
|
Raises:
|
||||||
data = InputFile(data)
|
TelegramError
|
||||||
result = _request_wrapper('POST', url, body=data.to_form(), headers=data.headers)
|
|
||||||
else:
|
|
||||||
data = json.dumps(data)
|
|
||||||
result = _request_wrapper(
|
|
||||||
'POST',
|
|
||||||
url,
|
|
||||||
body=data.encode(),
|
|
||||||
headers={'Content-Type': 'application/json'},
|
|
||||||
**urlopen_kwargs)
|
|
||||||
|
|
||||||
return _parse(result)
|
"""
|
||||||
|
try:
|
||||||
|
resp = self._con_pool.request(*args, **kwargs)
|
||||||
|
except urllib3.exceptions.TimeoutError:
|
||||||
|
raise TimedOut()
|
||||||
|
except urllib3.exceptions.HTTPError as error:
|
||||||
|
# HTTPError must come last as its the base urllib3 exception class
|
||||||
|
# TODO: do something smart here; for now just raise NetworkError
|
||||||
|
raise NetworkError('urllib3 HTTPError {0}'.format(error))
|
||||||
|
|
||||||
|
if 200 <= resp.status <= 299:
|
||||||
|
# 200-299 range are HTTP success statuses
|
||||||
|
return resp.data
|
||||||
|
|
||||||
def download(url, filename):
|
try:
|
||||||
"""Download a file by its URL.
|
message = self._parse(resp.data)
|
||||||
Args:
|
except ValueError:
|
||||||
url:
|
raise NetworkError('Unknown HTTPError {0}'.format(resp.status))
|
||||||
The web location we want to retrieve.
|
|
||||||
|
|
||||||
filename:
|
if resp.status in (401, 403):
|
||||||
The filename within the path to download the file.
|
raise Unauthorized()
|
||||||
|
elif resp.status == 400:
|
||||||
|
raise BadRequest(repr(message))
|
||||||
|
elif resp.status == 502:
|
||||||
|
raise NetworkError('Bad Gateway')
|
||||||
|
else:
|
||||||
|
raise NetworkError('{0} ({1})'.format(message, resp.status))
|
||||||
|
|
||||||
"""
|
def get(self, url):
|
||||||
buf = _request_wrapper('GET', url)
|
"""Request an URL.
|
||||||
with open(filename, 'wb') as fobj:
|
Args:
|
||||||
fobj.write(buf)
|
url:
|
||||||
|
The web location we want to retrieve.
|
||||||
|
|
||||||
|
Returns:
|
||||||
|
A JSON object.
|
||||||
|
|
||||||
|
"""
|
||||||
|
result = self._request_wrapper('GET', url)
|
||||||
|
return self._parse(result)
|
||||||
|
|
||||||
|
def post(self, url, data, timeout=None):
|
||||||
|
"""Request an URL.
|
||||||
|
Args:
|
||||||
|
url:
|
||||||
|
The web location we want to retrieve.
|
||||||
|
data:
|
||||||
|
A dict of (str, unicode) key/value pairs.
|
||||||
|
timeout:
|
||||||
|
float. If this value is specified, use it as the definitive timeout (in
|
||||||
|
seconds) for urlopen() operations. [Optional]
|
||||||
|
|
||||||
|
Notes:
|
||||||
|
If neither `timeout` nor `data['timeout']` is specified. The underlying
|
||||||
|
defaults are used.
|
||||||
|
|
||||||
|
Returns:
|
||||||
|
A JSON object.
|
||||||
|
|
||||||
|
"""
|
||||||
|
urlopen_kwargs = {}
|
||||||
|
|
||||||
|
if timeout is not None:
|
||||||
|
urlopen_kwargs['timeout'] = timeout
|
||||||
|
|
||||||
|
if InputFile.is_inputfile(data):
|
||||||
|
data = InputFile(data)
|
||||||
|
result = self._request_wrapper('POST', url, body=data.to_form(), headers=data.headers)
|
||||||
|
else:
|
||||||
|
data = json.dumps(data)
|
||||||
|
result = self._request_wrapper(
|
||||||
|
'POST',
|
||||||
|
url,
|
||||||
|
body=data.encode(),
|
||||||
|
headers={'Content-Type': 'application/json'},
|
||||||
|
**urlopen_kwargs)
|
||||||
|
|
||||||
|
return self._parse(result)
|
||||||
|
|
||||||
|
def download(self, url, filename):
|
||||||
|
"""Download a file by its URL.
|
||||||
|
Args:
|
||||||
|
url:
|
||||||
|
The web location we want to retrieve.
|
||||||
|
|
||||||
|
filename:
|
||||||
|
The filename within the path to download the file.
|
||||||
|
|
||||||
|
"""
|
||||||
|
buf = self._request_wrapper('GET', url)
|
||||||
|
with open(filename, 'wb') as fobj:
|
||||||
|
fobj.write(buf)
|
||||||
|
|
|
@ -198,7 +198,9 @@ class BotTest(BaseTest, unittest.TestCase):
|
||||||
def testInvalidSrvResp(self):
|
def testInvalidSrvResp(self):
|
||||||
with self.assertRaisesRegexp(telegram.TelegramError, 'Invalid server response'):
|
with self.assertRaisesRegexp(telegram.TelegramError, 'Invalid server response'):
|
||||||
# bypass the valid token check
|
# bypass the valid token check
|
||||||
bot = telegram.Bot.__new__(telegram.Bot)
|
newbot_cls = type(
|
||||||
|
'NoTokenValidateBot', (telegram.Bot,), dict(_validate_token=lambda x, y: None))
|
||||||
|
bot = newbot_cls('0xdeadbeef')
|
||||||
bot.base_url = 'https://api.telegram.org/bot{0}'.format('12')
|
bot.base_url = 'https://api.telegram.org/bot{0}'.format('12')
|
||||||
|
|
||||||
bot.getMe()
|
bot.getMe()
|
||||||
|
|
|
@ -36,7 +36,6 @@ except ImportError:
|
||||||
sys.path.append('.')
|
sys.path.append('.')
|
||||||
|
|
||||||
from telegram import Update, Message, TelegramError, User, Chat, Bot
|
from telegram import Update, Message, TelegramError, User, Chat, Bot
|
||||||
from telegram.utils.request import stop_con_pool
|
|
||||||
from telegram.ext import *
|
from telegram.ext import *
|
||||||
from tests.base import BaseTest
|
from tests.base import BaseTest
|
||||||
from tests.test_updater import MockBot
|
from tests.test_updater import MockBot
|
||||||
|
@ -61,10 +60,10 @@ class ConversationHandlerTest(BaseTest, unittest.TestCase):
|
||||||
# At first we're thirsty. Then we brew coffee, we drink it
|
# At first we're thirsty. Then we brew coffee, we drink it
|
||||||
# and then we can start coding!
|
# and then we can start coding!
|
||||||
END, THIRSTY, BREWING, DRINKING, CODING = range(-1, 4)
|
END, THIRSTY, BREWING, DRINKING, CODING = range(-1, 4)
|
||||||
|
_updater = None
|
||||||
|
|
||||||
# Test related
|
# Test related
|
||||||
def setUp(self):
|
def setUp(self):
|
||||||
self.updater = None
|
|
||||||
self.current_state = dict()
|
self.current_state = dict()
|
||||||
self.entry_points = [CommandHandler('start', self.start)]
|
self.entry_points = [CommandHandler('start', self.start)]
|
||||||
self.states = {self.THIRSTY: [CommandHandler('brew', self.brew),
|
self.states = {self.THIRSTY: [CommandHandler('brew', self.brew),
|
||||||
|
@ -78,14 +77,22 @@ class ConversationHandlerTest(BaseTest, unittest.TestCase):
|
||||||
self.fallbacks = [CommandHandler('eat', self.start)]
|
self.fallbacks = [CommandHandler('eat', self.start)]
|
||||||
|
|
||||||
def _setup_updater(self, *args, **kwargs):
|
def _setup_updater(self, *args, **kwargs):
|
||||||
stop_con_pool()
|
|
||||||
bot = MockBot(*args, **kwargs)
|
bot = MockBot(*args, **kwargs)
|
||||||
self.updater = Updater(workers=2, bot=bot)
|
self.updater = Updater(workers=2, bot=bot)
|
||||||
|
|
||||||
def tearDown(self):
|
def tearDown(self):
|
||||||
if self.updater is not None:
|
if self.updater is not None:
|
||||||
self.updater.stop()
|
self.updater.stop()
|
||||||
stop_con_pool()
|
|
||||||
|
@property
|
||||||
|
def updater(self):
|
||||||
|
return self._updater
|
||||||
|
|
||||||
|
@updater.setter
|
||||||
|
def updater(self, val):
|
||||||
|
if self._updater:
|
||||||
|
self._updater.stop()
|
||||||
|
self._updater = val
|
||||||
|
|
||||||
def reset(self):
|
def reset(self):
|
||||||
self.current_state = dict()
|
self.current_state = dict()
|
||||||
|
|
|
@ -101,19 +101,19 @@ class FileTest(BaseTest, unittest.TestCase):
|
||||||
self.assertTrue(os.path.isfile('telegram.ogg'))
|
self.assertTrue(os.path.isfile('telegram.ogg'))
|
||||||
|
|
||||||
def test_file_de_json(self):
|
def test_file_de_json(self):
|
||||||
newFile = telegram.File.de_json(self.json_dict)
|
newFile = telegram.File.de_json(self.json_dict, None)
|
||||||
|
|
||||||
self.assertEqual(newFile.file_id, self.json_dict['file_id'])
|
self.assertEqual(newFile.file_id, self.json_dict['file_id'])
|
||||||
self.assertEqual(newFile.file_path, self.json_dict['file_path'])
|
self.assertEqual(newFile.file_path, self.json_dict['file_path'])
|
||||||
self.assertEqual(newFile.file_size, self.json_dict['file_size'])
|
self.assertEqual(newFile.file_size, self.json_dict['file_size'])
|
||||||
|
|
||||||
def test_file_to_json(self):
|
def test_file_to_json(self):
|
||||||
newFile = telegram.File.de_json(self.json_dict)
|
newFile = telegram.File.de_json(self.json_dict, None)
|
||||||
|
|
||||||
self.assertTrue(self.is_json(newFile.to_json()))
|
self.assertTrue(self.is_json(newFile.to_json()))
|
||||||
|
|
||||||
def test_file_to_dict(self):
|
def test_file_to_dict(self):
|
||||||
newFile = telegram.File.de_json(self.json_dict)
|
newFile = telegram.File.de_json(self.json_dict, None)
|
||||||
|
|
||||||
self.assertTrue(self.is_dict(newFile.to_dict()))
|
self.assertTrue(self.is_dict(newFile.to_dict()))
|
||||||
self.assertEqual(newFile['file_id'], self.json_dict['file_id'])
|
self.assertEqual(newFile['file_id'], self.json_dict['file_id'])
|
||||||
|
|
|
@ -25,9 +25,10 @@ import sys
|
||||||
import unittest
|
import unittest
|
||||||
from time import sleep
|
from time import sleep
|
||||||
|
|
||||||
|
from tests.test_updater import MockBot
|
||||||
|
|
||||||
sys.path.append('.')
|
sys.path.append('.')
|
||||||
|
|
||||||
from telegram.utils.request import stop_con_pool
|
|
||||||
from telegram.ext import JobQueue, Job, Updater
|
from telegram.ext import JobQueue, Job, Updater
|
||||||
from tests.base import BaseTest
|
from tests.base import BaseTest
|
||||||
|
|
||||||
|
@ -49,13 +50,12 @@ class JobQueueTest(BaseTest, unittest.TestCase):
|
||||||
"""
|
"""
|
||||||
|
|
||||||
def setUp(self):
|
def setUp(self):
|
||||||
self.jq = JobQueue("Bot")
|
self.jq = JobQueue(MockBot('jobqueue_test'))
|
||||||
self.result = 0
|
self.result = 0
|
||||||
|
|
||||||
def tearDown(self):
|
def tearDown(self):
|
||||||
if self.jq is not None:
|
if self.jq is not None:
|
||||||
self.jq.stop()
|
self.jq.stop()
|
||||||
stop_con_pool()
|
|
||||||
|
|
||||||
def job1(self, bot, job):
|
def job1(self, bot, job):
|
||||||
self.result += 1
|
self.result += 1
|
||||||
|
@ -158,12 +158,15 @@ class JobQueueTest(BaseTest, unittest.TestCase):
|
||||||
|
|
||||||
def test_inUpdater(self):
|
def test_inUpdater(self):
|
||||||
u = Updater(bot="MockBot")
|
u = Updater(bot="MockBot")
|
||||||
u.job_queue.put(Job(self.job1, 0.5))
|
try:
|
||||||
sleep(0.75)
|
u.job_queue.put(Job(self.job1, 0.5))
|
||||||
self.assertEqual(1, self.result)
|
sleep(0.75)
|
||||||
u.stop()
|
self.assertEqual(1, self.result)
|
||||||
sleep(2)
|
u.stop()
|
||||||
self.assertEqual(1, self.result)
|
sleep(2)
|
||||||
|
self.assertEqual(1, self.result)
|
||||||
|
finally:
|
||||||
|
u.stop()
|
||||||
|
|
||||||
|
|
||||||
if __name__ == '__main__':
|
if __name__ == '__main__':
|
||||||
|
|
|
@ -29,10 +29,13 @@ import re
|
||||||
import unittest
|
import unittest
|
||||||
from datetime import datetime
|
from datetime import datetime
|
||||||
from time import sleep
|
from time import sleep
|
||||||
|
from queue import Queue
|
||||||
from random import randrange
|
from random import randrange
|
||||||
|
|
||||||
from future.builtins import bytes
|
from future.builtins import bytes
|
||||||
|
|
||||||
|
from telegram.utils.request import Request as Requester
|
||||||
|
|
||||||
try:
|
try:
|
||||||
# python2
|
# python2
|
||||||
from urllib2 import urlopen, Request, HTTPError
|
from urllib2 import urlopen, Request, HTTPError
|
||||||
|
@ -44,12 +47,11 @@ except ImportError:
|
||||||
sys.path.append('.')
|
sys.path.append('.')
|
||||||
|
|
||||||
from telegram import Update, Message, TelegramError, User, Chat, Bot, InlineQuery, CallbackQuery
|
from telegram import Update, Message, TelegramError, User, Chat, Bot, InlineQuery, CallbackQuery
|
||||||
from telegram.utils.request import stop_con_pool
|
|
||||||
from telegram.ext import *
|
from telegram.ext import *
|
||||||
from telegram.ext.dispatcher import run_async
|
from telegram.ext.dispatcher import run_async
|
||||||
from telegram.error import Unauthorized, InvalidToken
|
from telegram.error import Unauthorized, InvalidToken
|
||||||
from tests.base import BaseTest
|
from tests.base import BaseTest
|
||||||
from threading import Lock, Thread
|
from threading import Lock, Thread, current_thread, Semaphore
|
||||||
|
|
||||||
# Enable logging
|
# Enable logging
|
||||||
root = logging.getLogger()
|
root = logging.getLogger()
|
||||||
|
@ -68,10 +70,8 @@ class UpdaterTest(BaseTest, unittest.TestCase):
|
||||||
WebhookHandler
|
WebhookHandler
|
||||||
"""
|
"""
|
||||||
|
|
||||||
updater = None
|
_updater = None
|
||||||
received_message = None
|
received_message = None
|
||||||
message_count = None
|
|
||||||
lock = None
|
|
||||||
|
|
||||||
def setUp(self):
|
def setUp(self):
|
||||||
self.updater = None
|
self.updater = None
|
||||||
|
@ -79,15 +79,25 @@ class UpdaterTest(BaseTest, unittest.TestCase):
|
||||||
self.message_count = 0
|
self.message_count = 0
|
||||||
self.lock = Lock()
|
self.lock = Lock()
|
||||||
|
|
||||||
|
@property
|
||||||
|
def updater(self):
|
||||||
|
return self._updater
|
||||||
|
|
||||||
|
@updater.setter
|
||||||
|
def updater(self, val):
|
||||||
|
if self._updater:
|
||||||
|
self._updater.stop()
|
||||||
|
self._updater.dispatcher._reset_singleton()
|
||||||
|
del self._updater.dispatcher
|
||||||
|
|
||||||
|
self._updater = val
|
||||||
|
|
||||||
def _setup_updater(self, *args, **kwargs):
|
def _setup_updater(self, *args, **kwargs):
|
||||||
stop_con_pool()
|
|
||||||
bot = MockBot(*args, **kwargs)
|
bot = MockBot(*args, **kwargs)
|
||||||
self.updater = Updater(workers=2, bot=bot)
|
self.updater = Updater(workers=2, bot=bot)
|
||||||
|
|
||||||
def tearDown(self):
|
def tearDown(self):
|
||||||
if self.updater is not None:
|
self.updater = None
|
||||||
self.updater.stop()
|
|
||||||
stop_con_pool()
|
|
||||||
|
|
||||||
def reset(self):
|
def reset(self):
|
||||||
self.message_count = 0
|
self.message_count = 0
|
||||||
|
@ -411,6 +421,51 @@ class UpdaterTest(BaseTest, unittest.TestCase):
|
||||||
self.assertEqual(self.received_message, 'Test5')
|
self.assertEqual(self.received_message, 'Test5')
|
||||||
self.assertEqual(self.message_count, 2)
|
self.assertEqual(self.message_count, 2)
|
||||||
|
|
||||||
|
def test_multiple_dispatchers(self):
|
||||||
|
|
||||||
|
def get_dispatcher_name(q):
|
||||||
|
q.put(current_thread().name)
|
||||||
|
sleep(1.2)
|
||||||
|
|
||||||
|
d1 = Dispatcher(MockBot('disp1'), Queue(), workers=1)
|
||||||
|
d2 = Dispatcher(MockBot('disp2'), Queue(), workers=1)
|
||||||
|
q1 = Queue()
|
||||||
|
q2 = Queue()
|
||||||
|
|
||||||
|
try:
|
||||||
|
d1.run_async(get_dispatcher_name, q1)
|
||||||
|
d2.run_async(get_dispatcher_name, q2)
|
||||||
|
|
||||||
|
name1 = q1.get()
|
||||||
|
name2 = q2.get()
|
||||||
|
|
||||||
|
self.assertNotEqual(name1, name2)
|
||||||
|
finally:
|
||||||
|
d1.stop()
|
||||||
|
d2.stop()
|
||||||
|
# following three lines are for pypy unitests
|
||||||
|
d1._reset_singleton()
|
||||||
|
del d1
|
||||||
|
del d2
|
||||||
|
|
||||||
|
def test_multiple_dispatcers_no_decorator(self):
|
||||||
|
|
||||||
|
@run_async
|
||||||
|
def must_raise_runtime_error():
|
||||||
|
pass
|
||||||
|
|
||||||
|
d1 = Dispatcher(MockBot('disp1'), Queue(), workers=1)
|
||||||
|
d2 = Dispatcher(MockBot('disp2'), Queue(), workers=1)
|
||||||
|
|
||||||
|
self.assertRaises(RuntimeError, must_raise_runtime_error)
|
||||||
|
|
||||||
|
d1.stop()
|
||||||
|
d2.stop()
|
||||||
|
# following three lines are for pypy unitests
|
||||||
|
d1._reset_singleton()
|
||||||
|
del d1
|
||||||
|
del d2
|
||||||
|
|
||||||
def test_additionalArgs(self):
|
def test_additionalArgs(self):
|
||||||
self._setup_updater('', messages=0)
|
self._setup_updater('', messages=0)
|
||||||
handler = StringCommandHandler(
|
handler = StringCommandHandler(
|
||||||
|
|
Loading…
Reference in a new issue