Merge branch 'master' into jobqueue-rework

Conflicts:
	tests/test_jobqueue.py
This commit is contained in:
Jannes Höke 2016-06-20 05:32:15 +02:00
commit c4a8ee5175
7 changed files with 162 additions and 111 deletions

View file

@ -1 +1,3 @@
future
future>=0.15.2
urllib3>=1.8.3
certifi

View file

@ -20,24 +20,47 @@
import logging
from functools import wraps
from threading import Thread, BoundedSemaphore, Lock, Event, current_thread
from threading import Thread, Lock, Event, current_thread
from time import sleep
from queue import Queue, Empty
from queue import Empty
from future.builtins import range
from telegram import (TelegramError, NullHandler)
from telegram.utils import request
from telegram.ext.handler import Handler
from telegram.utils.deprecate import deprecate
logging.getLogger(__name__).addHandler(NullHandler())
semaphore = None
async_threads = set()
ASYNC_QUEUE = Queue()
ASYNC_THREADS = set()
""":type: set[Thread]"""
async_lock = Lock()
ASYNC_LOCK = Lock() # guards ASYNC_THREADS
DEFAULT_GROUP = 0
def _pooled():
"""
A wrapper to run a thread in a thread pool
"""
while 1:
try:
func, args, kwargs = ASYNC_QUEUE.get()
# If unpacking fails, the thread pool is being closed from Updater._join_async_threads
except TypeError:
logging.getLogger(__name__).debug("Closing run_async thread %s/%d" %
(current_thread().getName(), len(ASYNC_THREADS)))
break
try:
func(*args, **kwargs)
except:
logging.getLogger(__name__).exception("run_async function raised exception")
def run_async(func):
"""
Function decorator that will run the function in a new thread.
@ -53,30 +76,11 @@ def run_async(func):
# set a threading.Event to notify caller thread
@wraps(func)
def pooled(*pargs, **kwargs):
"""
A wrapper to run a thread in a thread pool
"""
try:
result = func(*pargs, **kwargs)
finally:
semaphore.release()
with async_lock:
async_threads.remove(current_thread())
return result
@wraps(func)
def async_func(*pargs, **kwargs):
def async_func(*args, **kwargs):
"""
A wrapper to run a function in a thread
"""
thread = Thread(target=pooled, args=pargs, kwargs=kwargs)
semaphore.acquire()
with async_lock:
async_threads.add(thread)
thread.start()
return thread
ASYNC_QUEUE.put((func, args, kwargs))
return async_func
@ -112,11 +116,18 @@ class Dispatcher(object):
self.__stop_event = Event()
self.__exception_event = exception_event or Event()
global semaphore
if not semaphore:
semaphore = BoundedSemaphore(value=workers)
else:
self.logger.debug('Semaphore already initialized, skipping.')
with ASYNC_LOCK:
if not ASYNC_THREADS:
if request.is_con_pool_initialized():
raise RuntimeError('Connection Pool already initialized')
request.CON_POOL_SIZE = workers + 3
for i in range(workers):
thread = Thread(target=_pooled, name=str(i))
ASYNC_THREADS.add(thread)
thread.start()
else:
self.logger.debug('Thread pool already initialized, skipping.')
def start(self):
"""
@ -136,7 +147,7 @@ class Dispatcher(object):
self.running = True
self.logger.debug('Dispatcher started')
while True:
while 1:
try:
# Pop update from update queue.
update = self.update_queue.get(True, 1)
@ -150,7 +161,7 @@ class Dispatcher(object):
continue
self.logger.debug('Processing Update: %s' % update)
self.processUpdate(update)
self.process_update(update)
self.running = False
self.logger.debug('Dispatcher thread stopped')
@ -165,7 +176,7 @@ class Dispatcher(object):
sleep(0.1)
self.__stop_event.clear()
def processUpdate(self, update):
def process_update(self, update):
"""
Processes a single update.
@ -175,7 +186,7 @@ class Dispatcher(object):
# An error happened while polling
if isinstance(update, TelegramError):
self.dispatchError(None, update)
self.dispatch_error(None, update)
else:
for group in self.groups:
@ -190,7 +201,7 @@ class Dispatcher(object):
'Update.')
try:
self.dispatchError(update, te)
self.dispatch_error(update, te)
except Exception:
self.logger.exception('An uncaught error was raised while '
'handling the error')
@ -276,7 +287,7 @@ class Dispatcher(object):
if callback in self.error_handlers:
self.error_handlers.remove(callback)
def dispatchError(self, update, error):
def dispatch_error(self, update, error):
"""
Dispatches an error.

View file

@ -308,7 +308,7 @@ class Updater(object):
def _bootstrap(self, max_retries, clean, webhook_url, cert=None):
retries = 0
while True:
while 1:
try:
if clean:
@ -345,7 +345,7 @@ class Updater(object):
self.job_queue.stop()
with self.__lock:
if self.running:
if self.running or dispatcher.ASYNC_THREADS:
self.logger.debug('Stopping Updater and Dispatcher...')
self.running = False
@ -353,9 +353,8 @@ class Updater(object):
self._stop_httpd()
self._stop_dispatcher()
self._join_threads()
# async threads must be join()ed only after the dispatcher
# thread was joined, otherwise we can still have new async
# threads dispatched
# async threads must be join()ed only after the dispatcher thread was joined,
# otherwise we can still have new async threads dispatched
self._join_async_threads()
def _stop_httpd(self):
@ -371,13 +370,19 @@ class Updater(object):
self.dispatcher.stop()
def _join_async_threads(self):
with dispatcher.async_lock:
threads = list(dispatcher.async_threads)
total = len(threads)
for i, thr in enumerate(threads):
self.logger.debug('Waiting for async thread {0}/{1} to end'.format(i, total))
thr.join()
self.logger.debug('async thread {0}/{1} has ended'.format(i, total))
with dispatcher.ASYNC_LOCK:
threads = list(dispatcher.ASYNC_THREADS)
total = len(threads)
# Stop all threads in the thread pool by put()ting one non-tuple per thread
for i in range(total):
dispatcher.ASYNC_QUEUE.put(None)
for i, thr in enumerate(threads):
self.logger.debug('Waiting for async thread {0}/{1} to end'.format(i + 1, total))
thr.join()
dispatcher.ASYNC_THREADS.remove(thr)
self.logger.debug('async thread {0}/{1} has ended'.format(i + 1, total))
def _join_threads(self):
for thr in self.__threads:

View file

@ -18,29 +18,56 @@
# along with this program. If not, see [http://www.gnu.org/licenses/].
"""This module contains methods to make POST and GET requests"""
import functools
import json
import socket
from ssl import SSLError
import logging
from future.moves.http.client import HTTPException
from future.moves.urllib.error import HTTPError, URLError
from future.moves.urllib.request import urlopen, urlretrieve, Request
import certifi
import urllib3
from urllib3.connection import HTTPConnection
from telegram import (InputFile, TelegramError)
from telegram.error import Unauthorized, NetworkError, TimedOut, BadRequest
_CON_POOL = None
""":type: urllib3.PoolManager"""
CON_POOL_SIZE = 1
logging.getLogger('urllib3').setLevel(logging.WARNING)
def _get_con_pool():
global _CON_POOL
if _CON_POOL is not None:
return _CON_POOL
_CON_POOL = urllib3.PoolManager(maxsize=CON_POOL_SIZE,
cert_reqs='CERT_REQUIRED',
ca_certs=certifi.where(),
socket_options=HTTPConnection.default_socket_options + [
(socket.SOL_SOCKET, socket.SO_KEEPALIVE, 1),
])
return _CON_POOL
def is_con_pool_initialized():
return _CON_POOL is not None
def stop_con_pool():
global _CON_POOL
if _CON_POOL is not None:
_CON_POOL.clear()
_CON_POOL = None
def _parse(json_data):
"""Try and parse the JSON returned from Telegram and return an empty
dictionary if there is any error.
Args:
url:
urllib.urlopen object
"""Try and parse the JSON returned from Telegram.
Returns:
A JSON parsed as Python dict with results.
dict: A JSON parsed as Python dict with results - on error this dict will be empty.
"""
decoded_s = json_data.decode('utf-8')
try:
@ -54,53 +81,49 @@ def _parse(json_data):
return data['result']
def _try_except_req(func):
"""Decorator for requests to handle known exceptions"""
def _request_wrapper(*args, **kwargs):
"""Wraps urllib3 request for handling known exceptions.
@functools.wraps(func)
def decorator(*args, **kwargs):
try:
return func(*args, **kwargs)
Args:
args: unnamed arguments, passed to urllib3 request.
kwargs: keyword arguments, passed tp urllib3 request.
except HTTPError as error:
# `HTTPError` inherits from `URLError` so `HTTPError` handling must
# come first.
errcode = error.getcode()
Returns:
str: A non-parsed JSON text.
try:
message = _parse(error.read())
Raises:
TelegramError
if errcode in (401, 403):
raise Unauthorized()
elif errcode == 400:
raise BadRequest(message)
elif errcode == 502:
raise NetworkError('Bad Gateway')
except ValueError:
message = 'Unknown HTTPError {0}'.format(error.getcode())
"""
raise NetworkError('{0} ({1})'.format(message, errcode))
try:
resp = _get_con_pool().request(*args, **kwargs)
except urllib3.exceptions.TimeoutError as error:
raise TimedOut()
except urllib3.exceptions.HTTPError as error:
# HTTPError must come last as its the base urllib3 exception class
# TODO: do something smart here; for now just raise NetworkError
raise NetworkError('urllib3 HTTPError {0}'.format(error))
except URLError as error:
raise NetworkError('URLError: {0}'.format(error.reason))
if 200 <= resp.status <= 299:
# 200-299 range are HTTP success statuses
return resp.data
except (SSLError, socket.timeout) as error:
err_s = str(error)
if 'operation timed out' in err_s:
raise TimedOut()
try:
message = _parse(resp.data)
except ValueError:
raise NetworkError('Unknown HTTPError {0}'.format(resp.status))
raise NetworkError(err_s)
except HTTPException as error:
raise NetworkError('HTTPException: {0!r}'.format(error))
except socket.error as error:
raise NetworkError('socket.error: {0!r}'.format(error))
return decorator
if resp.status in (401, 403):
raise Unauthorized()
elif resp.status == 400:
raise BadRequest(repr(message))
elif resp.status == 502:
raise NetworkError('Bad Gateway')
else:
raise NetworkError('{0} ({1})'.format(message, resp.status))
@_try_except_req
def get(url):
"""Request an URL.
Args:
@ -109,13 +132,13 @@ def get(url):
Returns:
A JSON object.
"""
result = urlopen(url).read()
result = _request_wrapper('GET', url)
return _parse(result)
@_try_except_req
def post(url, data, timeout=None):
"""Request an URL.
Args:
@ -142,16 +165,17 @@ def post(url, data, timeout=None):
if InputFile.is_inputfile(data):
data = InputFile(data)
request = Request(url, data=data.to_form(), headers=data.headers)
result = _request_wrapper('POST', url, body=data.to_form(), headers=data.headers)
else:
data = json.dumps(data)
request = Request(url, data=data.encode(), headers={'Content-Type': 'application/json'})
result = _request_wrapper('POST',
url,
body=data.encode(),
headers={'Content-Type': 'application/json'})
result = urlopen(request, **urlopen_kwargs).read()
return _parse(result)
@_try_except_req
def download(url, filename):
"""Download a file by its URL.
Args:
@ -160,6 +184,8 @@ def download(url, filename):
filename:
The filename within the path to download the file.
"""
urlretrieve(url, filename)
"""
buf = _request_wrapper('GET', url)
with open(filename, 'wb') as fobj:
fobj.write(buf)

View file

@ -20,6 +20,7 @@
"""This module contains a object that represents Tests for Telegram Bot"""
import io
import re
from datetime import datetime
import sys
@ -211,10 +212,11 @@ class BotTest(BaseTest, unittest.TestCase):
@flaky(3, 1)
@timeout(10)
def testLeaveChat(self):
with self.assertRaisesRegexp(telegram.error.BadRequest, 'Chat not found'):
regex = re.compile('chat not found', re.IGNORECASE)
with self.assertRaisesRegexp(telegram.error.BadRequest, regex):
chat = self._bot.leaveChat(-123456)
with self.assertRaisesRegexp(telegram.error.NetworkError, 'Chat not found'):
with self.assertRaisesRegexp(telegram.error.NetworkError, regex):
chat = self._bot.leaveChat(-123456)
@flaky(3, 1)

View file

@ -31,6 +31,7 @@ else:
sys.path.append('.')
from telegram.utils.request import stop_con_pool
from telegram.ext import JobQueue, Job, Updater
from tests.base import BaseTest
@ -58,6 +59,7 @@ class JobQueueTest(BaseTest, unittest.TestCase):
def tearDown(self):
if self.jq is not None:
self.jq.stop()
stop_con_pool()
def job1(self, bot, job):
self.result += 1

View file

@ -48,6 +48,7 @@ except ImportError:
sys.path.append('.')
from telegram import Update, Message, TelegramError, User, Chat, Bot
from telegram.utils.request import stop_con_pool
from telegram.ext import *
from telegram.ext.dispatcher import run_async
from telegram.error import Unauthorized, InvalidToken
@ -83,12 +84,14 @@ class UpdaterTest(BaseTest, unittest.TestCase):
self.lock = Lock()
def _setup_updater(self, *args, **kwargs):
stop_con_pool()
bot = MockBot(*args, **kwargs)
self.updater = Updater(workers=2, bot=bot)
def tearDown(self):
if self.updater is not None:
self.updater.stop()
stop_con_pool()
def reset(self):
self.message_count = 0
@ -648,8 +651,8 @@ class UpdaterTest(BaseTest, unittest.TestCase):
self.assertFalse(self.updater.running)
def test_createBot(self):
updater = Updater('123:abcd')
self.assertIsNotNone(updater.bot)
self.updater = Updater('123:abcd')
self.assertIsNotNone(self.updater.bot)
def test_mutualExclusiveTokenBot(self):
bot = Bot('123:zyxw')