Replace http.server with Tornado (#1191)

Fixes #1189
This commit is contained in:
Kirill Vasin 2018-09-08 23:25:48 +03:00 committed by Noam Meltzer
parent b2fb4264a3
commit f2b06728e9
5 changed files with 110 additions and 116 deletions

View file

@ -42,6 +42,7 @@ The following wonderful people contributed directly or indirectly to this projec
- `Joscha Götzer <https://github.com/Rostgnom>`_
- `jossalgon <https://github.com/jossalgon>`_
- `JRoot3D <https://github.com/JRoot3D>`_
- `Kirill Vasin <https://github.com/vasinkd>`_
- `Kjwon15 <https://github.com/kjwon15>`_
- `Li-aung Yip <https://github.com/LiaungYip>`_
- `macrojames <https://github.com/macrojames>`_

View file

@ -1,3 +1,4 @@
future>=0.16.0
certifi
tornado>=5.1
cryptography

View file

@ -19,11 +19,9 @@
"""This module contains the class Updater, which tries to make creating Telegram bots intuitive."""
import logging
import os
import ssl
from threading import Thread, Lock, current_thread, Event
from time import sleep
import subprocess
from signal import signal, SIGINT, SIGTERM, SIGABRT
from queue import Queue
@ -32,7 +30,7 @@ from telegram.ext import Dispatcher, JobQueue
from telegram.error import Unauthorized, InvalidToken, RetryAfter, TimedOut
from telegram.utils.helpers import get_signal_name
from telegram.utils.request import Request
from telegram.utils.webhookhandler import (WebhookServer, WebhookHandler)
from telegram.utils.webhookhandler import (WebhookServer, WebhookAppClass)
logging.getLogger(__name__).addHandler(logging.NullHandler())
@ -356,13 +354,24 @@ class Updater(object):
if not url_path.startswith('/'):
url_path = '/{0}'.format(url_path)
# Create Tornado app instance
app = WebhookAppClass(url_path, self.bot, self.update_queue)
# Form SSL Context
# An SSLError is raised if the private key does not match with the certificate
if use_ssl:
try:
ssl_ctx = ssl.create_default_context(ssl.Purpose.CLIENT_AUTH)
ssl_ctx.load_cert_chain(cert, key)
except ssl.SSLError:
raise TelegramError('Invalid SSL Certificate')
else:
ssl_ctx = None
# Create and start server
self.httpd = WebhookServer((listen, port), WebhookHandler, self.update_queue, url_path,
self.bot)
self.httpd = WebhookServer(port, app, ssl_ctx)
if use_ssl:
self._check_ssl_cert(cert, key)
# DO NOT CHANGE: Only set webhook if SSL is handled by library
if not webhook_url:
webhook_url = self._gen_webhook_url(listen, port, url_path)
@ -377,26 +386,7 @@ class Updater(object):
self.logger.warning("cleaning updates is not supported if "
"SSL-termination happens elsewhere; skipping")
self.httpd.serve_forever(poll_interval=1)
def _check_ssl_cert(self, cert, key):
# Check SSL-Certificate with openssl, if possible
try:
exit_code = subprocess.call(
["openssl", "x509", "-text", "-noout", "-in", cert],
stdout=open(os.devnull, 'wb'),
stderr=subprocess.STDOUT)
except OSError:
exit_code = 0
if exit_code == 0:
try:
self.httpd.socket = ssl.wrap_socket(
self.httpd.socket, certfile=cert, keyfile=key, server_side=True)
except ssl.SSLError as error:
self.logger.exception('Failed to init SSL socket')
raise TelegramError(str(error))
else:
raise TelegramError('SSL Certificate invalid')
self.httpd.serve_forever()
@staticmethod
def _gen_webhook_url(listen, port, url_path):

View file

@ -17,7 +17,6 @@
# You should have received a copy of the GNU Lesser Public License
# along with this program. If not, see [http://www.gnu.org/licenses/].
import logging
from telegram import Update
from future.utils import bytes_to_native_str
from threading import Lock
@ -25,39 +24,35 @@ try:
import ujson as json
except ImportError:
import json
try:
import BaseHTTPServer
except ImportError:
import http.server as BaseHTTPServer
from tornado.httpserver import HTTPServer
from tornado.ioloop import IOLoop
import tornado.web
import tornado.iostream
logging.getLogger(__name__).addHandler(logging.NullHandler())
class _InvalidPost(Exception):
class WebhookServer(object):
def __init__(self, http_code):
self.http_code = http_code
super(_InvalidPost, self).__init__()
class WebhookServer(BaseHTTPServer.HTTPServer, object):
def __init__(self, server_address, RequestHandlerClass, update_queue, webhook_path, bot):
super(WebhookServer, self).__init__(server_address, RequestHandlerClass)
def __init__(self, port, webhook_app, ssl_ctx):
self.http_server = HTTPServer(webhook_app, ssl_options=ssl_ctx)
self.port = port
self.loop = None
self.logger = logging.getLogger(__name__)
self.update_queue = update_queue
self.webhook_path = webhook_path
self.bot = bot
self.is_running = False
self.server_lock = Lock()
self.shutdown_lock = Lock()
def serve_forever(self, poll_interval=0.5):
def serve_forever(self):
with self.server_lock:
IOLoop().make_current()
self.is_running = True
self.logger.debug('Webhook Server started.')
super(WebhookServer, self).serve_forever(poll_interval)
self.http_server.listen(self.port)
self.loop = IOLoop.current()
self.loop.start()
self.logger.debug('Webhook Server stopped.')
self.is_running = False
def shutdown(self):
with self.shutdown_lock:
@ -65,8 +60,7 @@ class WebhookServer(BaseHTTPServer.HTTPServer, object):
self.logger.warning('Webhook Server already stopped.')
return
else:
super(WebhookServer, self).shutdown()
self.is_running = False
self.loop.add_callback(self.loop.stop)
def handle_error(self, request, client_address):
"""Handle an error gracefully."""
@ -74,64 +68,52 @@ class WebhookServer(BaseHTTPServer.HTTPServer, object):
client_address, exc_info=True)
class WebhookAppClass(tornado.web.Application):
def __init__(self, webhook_path, bot, update_queue):
self.shared_objects = {"bot": bot, "update_queue": update_queue}
handlers = [
(r"{0}/?".format(webhook_path), WebhookHandler,
self.shared_objects)
] # noqa
tornado.web.Application.__init__(self, handlers)
def log_request(self, handler):
pass
# WebhookHandler, process webhook calls
# Based on: https://github.com/eternnoir/pyTelegramBotAPI/blob/master/
# examples/webhook_examples/webhook_cpython_echo_bot.py
class WebhookHandler(BaseHTTPServer.BaseHTTPRequestHandler, object):
server_version = 'WebhookHandler/1.0'
class WebhookHandler(tornado.web.RequestHandler):
SUPPORTED_METHODS = ["POST"]
def __init__(self, request, client_address, server):
def __init__(self, application, request, **kwargs):
super(WebhookHandler, self).__init__(application, request, **kwargs)
self.logger = logging.getLogger(__name__)
super(WebhookHandler, self).__init__(request, client_address, server)
def do_HEAD(self):
self.send_response(200)
self.end_headers()
def initialize(self, bot, update_queue):
self.bot = bot
self.update_queue = update_queue
def do_GET(self):
self.send_response(200)
self.end_headers()
def set_default_headers(self):
self.set_header("Content-Type", 'application/json; charset="utf-8"')
def do_POST(self):
def post(self):
self.logger.debug('Webhook triggered')
try:
self._validate_post()
clen = self._get_content_len()
except _InvalidPost as e:
self.send_error(e.http_code)
self.end_headers()
else:
buf = self.rfile.read(clen)
json_string = bytes_to_native_str(buf)
self.send_response(200)
self.end_headers()
self.logger.debug('Webhook received data: ' + json_string)
update = Update.de_json(json.loads(json_string), self.server.bot)
self.logger.debug('Received Update with ID %d on Webhook' % update.update_id)
self.server.update_queue.put(update)
self._validate_post()
json_string = bytes_to_native_str(self.request.body)
data = json.loads(json_string)
self.set_status(200)
self.logger.debug('Webhook received data: ' + json_string)
update = Update.de_json(data, self.bot)
self.logger.debug('Received Update with ID %d on Webhook' % update.update_id)
self.update_queue.put(update)
def _validate_post(self):
if not (self.path == self.server.webhook_path and 'content-type' in self.headers and
self.headers['content-type'] == 'application/json'):
raise _InvalidPost(403)
ct_header = self.request.headers.get("Content-Type", None)
if ct_header != 'application/json':
raise tornado.web.HTTPError(403)
def _get_content_len(self):
clen = self.headers.get('content-length')
if clen is None:
raise _InvalidPost(411)
try:
clen = int(clen)
except ValueError:
raise _InvalidPost(403)
if clen < 0:
raise _InvalidPost(403)
return clen
def log_message(self, format, *args):
def write_error(self, status_code, **kwargs):
"""Log an arbitrary message.
This is used by all other logging functions.
@ -145,4 +127,6 @@ class WebhookHandler(BaseHTTPServer.BaseHTTPRequestHandler, object):
The client ip is prefixed to every message.
"""
self.logger.debug("%s - - %s" % (self.address_string(), format % args))
super(WebhookHandler, self).write_error(status_code, **kwargs)
self.logger.debug("%s - - %s" % (self.request.remote_ip, "Exception in WebhookHandler"),
exc_info=kwargs['exc_info'])

View file

@ -150,14 +150,8 @@ class TestUpdater(object):
updater.start_webhook(
ip,
port,
url_path='TOKEN',
cert='./tests/test_updater.py',
key='./tests/test_updater.py', )
url_path='TOKEN')
sleep(.2)
# SSL-Wrapping will fail, so we start the server without SSL
thr = Thread(target=updater.httpd.serve_forever)
thr.start()
try:
# Now, we send an update to the server via urlopen
update = Update(1, message=Message(1, User(1, '', False), None, Chat(1, ''),
@ -166,21 +160,44 @@ class TestUpdater(object):
sleep(.2)
assert q.get(False) == update
response = self._send_webhook_msg(ip, port, None, 'webookhandler.py')
assert b'' == response.read()
assert 200 == response.code
# Returns 404 if path is incorrect
with pytest.raises(HTTPError) as excinfo:
self._send_webhook_msg(ip, port, None, 'webookhandler.py')
assert excinfo.value.code == 404
response = self._send_webhook_msg(ip, port, None, 'webookhandler.py',
get_method=lambda: 'HEAD')
assert b'' == response.read()
assert 200 == response.code
with pytest.raises(HTTPError) as excinfo:
self._send_webhook_msg(ip, port, None, 'webookhandler.py',
get_method=lambda: 'HEAD')
assert excinfo.value.code == 404
# Test multiple shutdown() calls
updater.httpd.shutdown()
finally:
updater.httpd.shutdown()
thr.join()
sleep(.2)
assert not updater.httpd.is_running
updater.stop()
def test_webhook_ssl(self, monkeypatch, updater):
monkeypatch.setattr('telegram.Bot.set_webhook', lambda *args, **kwargs: True)
monkeypatch.setattr('telegram.Bot.delete_webhook', lambda *args, **kwargs: True)
ip = '127.0.0.1'
port = randrange(1024, 49152) # Select random port for travis
tg_err = False
try:
updater._start_webhook(
ip,
port,
url_path='TOKEN',
cert='./tests/test_updater.py',
key='./tests/test_updater.py',
bootstrap_retries=0,
clean=False,
webhook_url=None,
allowed_updates=None)
except TelegramError:
tg_err = True
assert tg_err
def test_webhook_no_ssl(self, monkeypatch, updater):
q = Queue()
@ -199,6 +216,7 @@ class TestUpdater(object):
self._send_webhook_msg(ip, port, update.to_json())
sleep(.2)
assert q.get(False) == update
updater.stop()
@pytest.mark.parametrize(('error',),
argvalues=[(TelegramError(''),)],
@ -254,7 +272,7 @@ class TestUpdater(object):
with pytest.raises(HTTPError) as excinfo:
self._send_webhook_msg(ip, port, 'dummy-payload', content_len=-2)
assert excinfo.value.code == 403
assert excinfo.value.code == 500
# TODO: prevent urllib or the underlying from adding content-length
# with pytest.raises(HTTPError) as excinfo:
@ -263,7 +281,7 @@ class TestUpdater(object):
with pytest.raises(HTTPError):
self._send_webhook_msg(ip, port, 'dummy-payload', content_len='not-a-number')
assert excinfo.value.code == 403
assert excinfo.value.code == 500
finally:
updater.httpd.shutdown()