Replace http.server with Tornado (#1191)

Fixes #1189
This commit is contained in:
Kirill Vasin 2018-09-08 23:25:48 +03:00 committed by Noam Meltzer
parent b2fb4264a3
commit f2b06728e9
5 changed files with 110 additions and 116 deletions

View file

@ -42,6 +42,7 @@ The following wonderful people contributed directly or indirectly to this projec
- `Joscha Götzer <https://github.com/Rostgnom>`_ - `Joscha Götzer <https://github.com/Rostgnom>`_
- `jossalgon <https://github.com/jossalgon>`_ - `jossalgon <https://github.com/jossalgon>`_
- `JRoot3D <https://github.com/JRoot3D>`_ - `JRoot3D <https://github.com/JRoot3D>`_
- `Kirill Vasin <https://github.com/vasinkd>`_
- `Kjwon15 <https://github.com/kjwon15>`_ - `Kjwon15 <https://github.com/kjwon15>`_
- `Li-aung Yip <https://github.com/LiaungYip>`_ - `Li-aung Yip <https://github.com/LiaungYip>`_
- `macrojames <https://github.com/macrojames>`_ - `macrojames <https://github.com/macrojames>`_

View file

@ -1,3 +1,4 @@
future>=0.16.0 future>=0.16.0
certifi certifi
tornado>=5.1
cryptography cryptography

View file

@ -19,11 +19,9 @@
"""This module contains the class Updater, which tries to make creating Telegram bots intuitive.""" """This module contains the class Updater, which tries to make creating Telegram bots intuitive."""
import logging import logging
import os
import ssl import ssl
from threading import Thread, Lock, current_thread, Event from threading import Thread, Lock, current_thread, Event
from time import sleep from time import sleep
import subprocess
from signal import signal, SIGINT, SIGTERM, SIGABRT from signal import signal, SIGINT, SIGTERM, SIGABRT
from queue import Queue from queue import Queue
@ -32,7 +30,7 @@ from telegram.ext import Dispatcher, JobQueue
from telegram.error import Unauthorized, InvalidToken, RetryAfter, TimedOut from telegram.error import Unauthorized, InvalidToken, RetryAfter, TimedOut
from telegram.utils.helpers import get_signal_name from telegram.utils.helpers import get_signal_name
from telegram.utils.request import Request from telegram.utils.request import Request
from telegram.utils.webhookhandler import (WebhookServer, WebhookHandler) from telegram.utils.webhookhandler import (WebhookServer, WebhookAppClass)
logging.getLogger(__name__).addHandler(logging.NullHandler()) logging.getLogger(__name__).addHandler(logging.NullHandler())
@ -356,13 +354,24 @@ class Updater(object):
if not url_path.startswith('/'): if not url_path.startswith('/'):
url_path = '/{0}'.format(url_path) url_path = '/{0}'.format(url_path)
# Create Tornado app instance
app = WebhookAppClass(url_path, self.bot, self.update_queue)
# Form SSL Context
# An SSLError is raised if the private key does not match with the certificate
if use_ssl:
try:
ssl_ctx = ssl.create_default_context(ssl.Purpose.CLIENT_AUTH)
ssl_ctx.load_cert_chain(cert, key)
except ssl.SSLError:
raise TelegramError('Invalid SSL Certificate')
else:
ssl_ctx = None
# Create and start server # Create and start server
self.httpd = WebhookServer((listen, port), WebhookHandler, self.update_queue, url_path, self.httpd = WebhookServer(port, app, ssl_ctx)
self.bot)
if use_ssl: if use_ssl:
self._check_ssl_cert(cert, key)
# DO NOT CHANGE: Only set webhook if SSL is handled by library # DO NOT CHANGE: Only set webhook if SSL is handled by library
if not webhook_url: if not webhook_url:
webhook_url = self._gen_webhook_url(listen, port, url_path) webhook_url = self._gen_webhook_url(listen, port, url_path)
@ -377,26 +386,7 @@ class Updater(object):
self.logger.warning("cleaning updates is not supported if " self.logger.warning("cleaning updates is not supported if "
"SSL-termination happens elsewhere; skipping") "SSL-termination happens elsewhere; skipping")
self.httpd.serve_forever(poll_interval=1) self.httpd.serve_forever()
def _check_ssl_cert(self, cert, key):
# Check SSL-Certificate with openssl, if possible
try:
exit_code = subprocess.call(
["openssl", "x509", "-text", "-noout", "-in", cert],
stdout=open(os.devnull, 'wb'),
stderr=subprocess.STDOUT)
except OSError:
exit_code = 0
if exit_code == 0:
try:
self.httpd.socket = ssl.wrap_socket(
self.httpd.socket, certfile=cert, keyfile=key, server_side=True)
except ssl.SSLError as error:
self.logger.exception('Failed to init SSL socket')
raise TelegramError(str(error))
else:
raise TelegramError('SSL Certificate invalid')
@staticmethod @staticmethod
def _gen_webhook_url(listen, port, url_path): def _gen_webhook_url(listen, port, url_path):

View file

@ -17,7 +17,6 @@
# You should have received a copy of the GNU Lesser Public License # You should have received a copy of the GNU Lesser Public License
# along with this program. If not, see [http://www.gnu.org/licenses/]. # along with this program. If not, see [http://www.gnu.org/licenses/].
import logging import logging
from telegram import Update from telegram import Update
from future.utils import bytes_to_native_str from future.utils import bytes_to_native_str
from threading import Lock from threading import Lock
@ -25,39 +24,35 @@ try:
import ujson as json import ujson as json
except ImportError: except ImportError:
import json import json
try: from tornado.httpserver import HTTPServer
import BaseHTTPServer from tornado.ioloop import IOLoop
except ImportError: import tornado.web
import http.server as BaseHTTPServer import tornado.iostream
logging.getLogger(__name__).addHandler(logging.NullHandler()) logging.getLogger(__name__).addHandler(logging.NullHandler())
class _InvalidPost(Exception): class WebhookServer(object):
def __init__(self, http_code): def __init__(self, port, webhook_app, ssl_ctx):
self.http_code = http_code self.http_server = HTTPServer(webhook_app, ssl_options=ssl_ctx)
super(_InvalidPost, self).__init__() self.port = port
self.loop = None
class WebhookServer(BaseHTTPServer.HTTPServer, object):
def __init__(self, server_address, RequestHandlerClass, update_queue, webhook_path, bot):
super(WebhookServer, self).__init__(server_address, RequestHandlerClass)
self.logger = logging.getLogger(__name__) self.logger = logging.getLogger(__name__)
self.update_queue = update_queue
self.webhook_path = webhook_path
self.bot = bot
self.is_running = False self.is_running = False
self.server_lock = Lock() self.server_lock = Lock()
self.shutdown_lock = Lock() self.shutdown_lock = Lock()
def serve_forever(self, poll_interval=0.5): def serve_forever(self):
with self.server_lock: with self.server_lock:
IOLoop().make_current()
self.is_running = True self.is_running = True
self.logger.debug('Webhook Server started.') self.logger.debug('Webhook Server started.')
super(WebhookServer, self).serve_forever(poll_interval) self.http_server.listen(self.port)
self.loop = IOLoop.current()
self.loop.start()
self.logger.debug('Webhook Server stopped.') self.logger.debug('Webhook Server stopped.')
self.is_running = False
def shutdown(self): def shutdown(self):
with self.shutdown_lock: with self.shutdown_lock:
@ -65,8 +60,7 @@ class WebhookServer(BaseHTTPServer.HTTPServer, object):
self.logger.warning('Webhook Server already stopped.') self.logger.warning('Webhook Server already stopped.')
return return
else: else:
super(WebhookServer, self).shutdown() self.loop.add_callback(self.loop.stop)
self.is_running = False
def handle_error(self, request, client_address): def handle_error(self, request, client_address):
"""Handle an error gracefully.""" """Handle an error gracefully."""
@ -74,64 +68,52 @@ class WebhookServer(BaseHTTPServer.HTTPServer, object):
client_address, exc_info=True) client_address, exc_info=True)
class WebhookAppClass(tornado.web.Application):
def __init__(self, webhook_path, bot, update_queue):
self.shared_objects = {"bot": bot, "update_queue": update_queue}
handlers = [
(r"{0}/?".format(webhook_path), WebhookHandler,
self.shared_objects)
] # noqa
tornado.web.Application.__init__(self, handlers)
def log_request(self, handler):
pass
# WebhookHandler, process webhook calls # WebhookHandler, process webhook calls
# Based on: https://github.com/eternnoir/pyTelegramBotAPI/blob/master/ class WebhookHandler(tornado.web.RequestHandler):
# examples/webhook_examples/webhook_cpython_echo_bot.py SUPPORTED_METHODS = ["POST"]
class WebhookHandler(BaseHTTPServer.BaseHTTPRequestHandler, object):
server_version = 'WebhookHandler/1.0'
def __init__(self, request, client_address, server): def __init__(self, application, request, **kwargs):
super(WebhookHandler, self).__init__(application, request, **kwargs)
self.logger = logging.getLogger(__name__) self.logger = logging.getLogger(__name__)
super(WebhookHandler, self).__init__(request, client_address, server)
def do_HEAD(self): def initialize(self, bot, update_queue):
self.send_response(200) self.bot = bot
self.end_headers() self.update_queue = update_queue
def do_GET(self): def set_default_headers(self):
self.send_response(200) self.set_header("Content-Type", 'application/json; charset="utf-8"')
self.end_headers()
def do_POST(self): def post(self):
self.logger.debug('Webhook triggered') self.logger.debug('Webhook triggered')
try: self._validate_post()
self._validate_post() json_string = bytes_to_native_str(self.request.body)
clen = self._get_content_len() data = json.loads(json_string)
except _InvalidPost as e: self.set_status(200)
self.send_error(e.http_code) self.logger.debug('Webhook received data: ' + json_string)
self.end_headers() update = Update.de_json(data, self.bot)
else: self.logger.debug('Received Update with ID %d on Webhook' % update.update_id)
buf = self.rfile.read(clen) self.update_queue.put(update)
json_string = bytes_to_native_str(buf)
self.send_response(200)
self.end_headers()
self.logger.debug('Webhook received data: ' + json_string)
update = Update.de_json(json.loads(json_string), self.server.bot)
self.logger.debug('Received Update with ID %d on Webhook' % update.update_id)
self.server.update_queue.put(update)
def _validate_post(self): def _validate_post(self):
if not (self.path == self.server.webhook_path and 'content-type' in self.headers and ct_header = self.request.headers.get("Content-Type", None)
self.headers['content-type'] == 'application/json'): if ct_header != 'application/json':
raise _InvalidPost(403) raise tornado.web.HTTPError(403)
def _get_content_len(self): def write_error(self, status_code, **kwargs):
clen = self.headers.get('content-length')
if clen is None:
raise _InvalidPost(411)
try:
clen = int(clen)
except ValueError:
raise _InvalidPost(403)
if clen < 0:
raise _InvalidPost(403)
return clen
def log_message(self, format, *args):
"""Log an arbitrary message. """Log an arbitrary message.
This is used by all other logging functions. This is used by all other logging functions.
@ -145,4 +127,6 @@ class WebhookHandler(BaseHTTPServer.BaseHTTPRequestHandler, object):
The client ip is prefixed to every message. The client ip is prefixed to every message.
""" """
self.logger.debug("%s - - %s" % (self.address_string(), format % args)) super(WebhookHandler, self).write_error(status_code, **kwargs)
self.logger.debug("%s - - %s" % (self.request.remote_ip, "Exception in WebhookHandler"),
exc_info=kwargs['exc_info'])

View file

@ -150,14 +150,8 @@ class TestUpdater(object):
updater.start_webhook( updater.start_webhook(
ip, ip,
port, port,
url_path='TOKEN', url_path='TOKEN')
cert='./tests/test_updater.py',
key='./tests/test_updater.py', )
sleep(.2) sleep(.2)
# SSL-Wrapping will fail, so we start the server without SSL
thr = Thread(target=updater.httpd.serve_forever)
thr.start()
try: try:
# Now, we send an update to the server via urlopen # Now, we send an update to the server via urlopen
update = Update(1, message=Message(1, User(1, '', False), None, Chat(1, ''), update = Update(1, message=Message(1, User(1, '', False), None, Chat(1, ''),
@ -166,21 +160,44 @@ class TestUpdater(object):
sleep(.2) sleep(.2)
assert q.get(False) == update assert q.get(False) == update
response = self._send_webhook_msg(ip, port, None, 'webookhandler.py') # Returns 404 if path is incorrect
assert b'' == response.read() with pytest.raises(HTTPError) as excinfo:
assert 200 == response.code self._send_webhook_msg(ip, port, None, 'webookhandler.py')
assert excinfo.value.code == 404
response = self._send_webhook_msg(ip, port, None, 'webookhandler.py', with pytest.raises(HTTPError) as excinfo:
get_method=lambda: 'HEAD') self._send_webhook_msg(ip, port, None, 'webookhandler.py',
get_method=lambda: 'HEAD')
assert b'' == response.read() assert excinfo.value.code == 404
assert 200 == response.code
# Test multiple shutdown() calls # Test multiple shutdown() calls
updater.httpd.shutdown() updater.httpd.shutdown()
finally: finally:
updater.httpd.shutdown() updater.httpd.shutdown()
thr.join() sleep(.2)
assert not updater.httpd.is_running
updater.stop()
def test_webhook_ssl(self, monkeypatch, updater):
monkeypatch.setattr('telegram.Bot.set_webhook', lambda *args, **kwargs: True)
monkeypatch.setattr('telegram.Bot.delete_webhook', lambda *args, **kwargs: True)
ip = '127.0.0.1'
port = randrange(1024, 49152) # Select random port for travis
tg_err = False
try:
updater._start_webhook(
ip,
port,
url_path='TOKEN',
cert='./tests/test_updater.py',
key='./tests/test_updater.py',
bootstrap_retries=0,
clean=False,
webhook_url=None,
allowed_updates=None)
except TelegramError:
tg_err = True
assert tg_err
def test_webhook_no_ssl(self, monkeypatch, updater): def test_webhook_no_ssl(self, monkeypatch, updater):
q = Queue() q = Queue()
@ -199,6 +216,7 @@ class TestUpdater(object):
self._send_webhook_msg(ip, port, update.to_json()) self._send_webhook_msg(ip, port, update.to_json())
sleep(.2) sleep(.2)
assert q.get(False) == update assert q.get(False) == update
updater.stop()
@pytest.mark.parametrize(('error',), @pytest.mark.parametrize(('error',),
argvalues=[(TelegramError(''),)], argvalues=[(TelegramError(''),)],
@ -254,7 +272,7 @@ class TestUpdater(object):
with pytest.raises(HTTPError) as excinfo: with pytest.raises(HTTPError) as excinfo:
self._send_webhook_msg(ip, port, 'dummy-payload', content_len=-2) self._send_webhook_msg(ip, port, 'dummy-payload', content_len=-2)
assert excinfo.value.code == 403 assert excinfo.value.code == 500
# TODO: prevent urllib or the underlying from adding content-length # TODO: prevent urllib or the underlying from adding content-length
# with pytest.raises(HTTPError) as excinfo: # with pytest.raises(HTTPError) as excinfo:
@ -263,7 +281,7 @@ class TestUpdater(object):
with pytest.raises(HTTPError): with pytest.raises(HTTPError):
self._send_webhook_msg(ip, port, 'dummy-payload', content_len='not-a-number') self._send_webhook_msg(ip, port, 'dummy-payload', content_len='not-a-number')
assert excinfo.value.code == 403 assert excinfo.value.code == 500
finally: finally:
updater.httpd.shutdown() updater.httpd.shutdown()