mirror of
https://github.com/python-telegram-bot/python-telegram-bot.git
synced 2025-03-27 08:50:38 +01:00
parent
b2fb4264a3
commit
f2b06728e9
5 changed files with 110 additions and 116 deletions
|
@ -42,6 +42,7 @@ The following wonderful people contributed directly or indirectly to this projec
|
|||
- `Joscha Götzer <https://github.com/Rostgnom>`_
|
||||
- `jossalgon <https://github.com/jossalgon>`_
|
||||
- `JRoot3D <https://github.com/JRoot3D>`_
|
||||
- `Kirill Vasin <https://github.com/vasinkd>`_
|
||||
- `Kjwon15 <https://github.com/kjwon15>`_
|
||||
- `Li-aung Yip <https://github.com/LiaungYip>`_
|
||||
- `macrojames <https://github.com/macrojames>`_
|
||||
|
|
|
@ -1,3 +1,4 @@
|
|||
future>=0.16.0
|
||||
certifi
|
||||
tornado>=5.1
|
||||
cryptography
|
||||
|
|
|
@ -19,11 +19,9 @@
|
|||
"""This module contains the class Updater, which tries to make creating Telegram bots intuitive."""
|
||||
|
||||
import logging
|
||||
import os
|
||||
import ssl
|
||||
from threading import Thread, Lock, current_thread, Event
|
||||
from time import sleep
|
||||
import subprocess
|
||||
from signal import signal, SIGINT, SIGTERM, SIGABRT
|
||||
from queue import Queue
|
||||
|
||||
|
@ -32,7 +30,7 @@ from telegram.ext import Dispatcher, JobQueue
|
|||
from telegram.error import Unauthorized, InvalidToken, RetryAfter, TimedOut
|
||||
from telegram.utils.helpers import get_signal_name
|
||||
from telegram.utils.request import Request
|
||||
from telegram.utils.webhookhandler import (WebhookServer, WebhookHandler)
|
||||
from telegram.utils.webhookhandler import (WebhookServer, WebhookAppClass)
|
||||
|
||||
logging.getLogger(__name__).addHandler(logging.NullHandler())
|
||||
|
||||
|
@ -356,13 +354,24 @@ class Updater(object):
|
|||
if not url_path.startswith('/'):
|
||||
url_path = '/{0}'.format(url_path)
|
||||
|
||||
# Create Tornado app instance
|
||||
app = WebhookAppClass(url_path, self.bot, self.update_queue)
|
||||
|
||||
# Form SSL Context
|
||||
# An SSLError is raised if the private key does not match with the certificate
|
||||
if use_ssl:
|
||||
try:
|
||||
ssl_ctx = ssl.create_default_context(ssl.Purpose.CLIENT_AUTH)
|
||||
ssl_ctx.load_cert_chain(cert, key)
|
||||
except ssl.SSLError:
|
||||
raise TelegramError('Invalid SSL Certificate')
|
||||
else:
|
||||
ssl_ctx = None
|
||||
|
||||
# Create and start server
|
||||
self.httpd = WebhookServer((listen, port), WebhookHandler, self.update_queue, url_path,
|
||||
self.bot)
|
||||
self.httpd = WebhookServer(port, app, ssl_ctx)
|
||||
|
||||
if use_ssl:
|
||||
self._check_ssl_cert(cert, key)
|
||||
|
||||
# DO NOT CHANGE: Only set webhook if SSL is handled by library
|
||||
if not webhook_url:
|
||||
webhook_url = self._gen_webhook_url(listen, port, url_path)
|
||||
|
@ -377,26 +386,7 @@ class Updater(object):
|
|||
self.logger.warning("cleaning updates is not supported if "
|
||||
"SSL-termination happens elsewhere; skipping")
|
||||
|
||||
self.httpd.serve_forever(poll_interval=1)
|
||||
|
||||
def _check_ssl_cert(self, cert, key):
|
||||
# Check SSL-Certificate with openssl, if possible
|
||||
try:
|
||||
exit_code = subprocess.call(
|
||||
["openssl", "x509", "-text", "-noout", "-in", cert],
|
||||
stdout=open(os.devnull, 'wb'),
|
||||
stderr=subprocess.STDOUT)
|
||||
except OSError:
|
||||
exit_code = 0
|
||||
if exit_code == 0:
|
||||
try:
|
||||
self.httpd.socket = ssl.wrap_socket(
|
||||
self.httpd.socket, certfile=cert, keyfile=key, server_side=True)
|
||||
except ssl.SSLError as error:
|
||||
self.logger.exception('Failed to init SSL socket')
|
||||
raise TelegramError(str(error))
|
||||
else:
|
||||
raise TelegramError('SSL Certificate invalid')
|
||||
self.httpd.serve_forever()
|
||||
|
||||
@staticmethod
|
||||
def _gen_webhook_url(listen, port, url_path):
|
||||
|
|
|
@ -17,7 +17,6 @@
|
|||
# You should have received a copy of the GNU Lesser Public License
|
||||
# along with this program. If not, see [http://www.gnu.org/licenses/].
|
||||
import logging
|
||||
|
||||
from telegram import Update
|
||||
from future.utils import bytes_to_native_str
|
||||
from threading import Lock
|
||||
|
@ -25,39 +24,35 @@ try:
|
|||
import ujson as json
|
||||
except ImportError:
|
||||
import json
|
||||
try:
|
||||
import BaseHTTPServer
|
||||
except ImportError:
|
||||
import http.server as BaseHTTPServer
|
||||
from tornado.httpserver import HTTPServer
|
||||
from tornado.ioloop import IOLoop
|
||||
import tornado.web
|
||||
import tornado.iostream
|
||||
|
||||
logging.getLogger(__name__).addHandler(logging.NullHandler())
|
||||
|
||||
|
||||
class _InvalidPost(Exception):
|
||||
class WebhookServer(object):
|
||||
|
||||
def __init__(self, http_code):
|
||||
self.http_code = http_code
|
||||
super(_InvalidPost, self).__init__()
|
||||
|
||||
|
||||
class WebhookServer(BaseHTTPServer.HTTPServer, object):
|
||||
|
||||
def __init__(self, server_address, RequestHandlerClass, update_queue, webhook_path, bot):
|
||||
super(WebhookServer, self).__init__(server_address, RequestHandlerClass)
|
||||
def __init__(self, port, webhook_app, ssl_ctx):
|
||||
self.http_server = HTTPServer(webhook_app, ssl_options=ssl_ctx)
|
||||
self.port = port
|
||||
self.loop = None
|
||||
self.logger = logging.getLogger(__name__)
|
||||
self.update_queue = update_queue
|
||||
self.webhook_path = webhook_path
|
||||
self.bot = bot
|
||||
self.is_running = False
|
||||
self.server_lock = Lock()
|
||||
self.shutdown_lock = Lock()
|
||||
|
||||
def serve_forever(self, poll_interval=0.5):
|
||||
def serve_forever(self):
|
||||
with self.server_lock:
|
||||
IOLoop().make_current()
|
||||
self.is_running = True
|
||||
self.logger.debug('Webhook Server started.')
|
||||
super(WebhookServer, self).serve_forever(poll_interval)
|
||||
self.http_server.listen(self.port)
|
||||
self.loop = IOLoop.current()
|
||||
self.loop.start()
|
||||
self.logger.debug('Webhook Server stopped.')
|
||||
self.is_running = False
|
||||
|
||||
def shutdown(self):
|
||||
with self.shutdown_lock:
|
||||
|
@ -65,8 +60,7 @@ class WebhookServer(BaseHTTPServer.HTTPServer, object):
|
|||
self.logger.warning('Webhook Server already stopped.')
|
||||
return
|
||||
else:
|
||||
super(WebhookServer, self).shutdown()
|
||||
self.is_running = False
|
||||
self.loop.add_callback(self.loop.stop)
|
||||
|
||||
def handle_error(self, request, client_address):
|
||||
"""Handle an error gracefully."""
|
||||
|
@ -74,64 +68,52 @@ class WebhookServer(BaseHTTPServer.HTTPServer, object):
|
|||
client_address, exc_info=True)
|
||||
|
||||
|
||||
class WebhookAppClass(tornado.web.Application):
|
||||
|
||||
def __init__(self, webhook_path, bot, update_queue):
|
||||
self.shared_objects = {"bot": bot, "update_queue": update_queue}
|
||||
handlers = [
|
||||
(r"{0}/?".format(webhook_path), WebhookHandler,
|
||||
self.shared_objects)
|
||||
] # noqa
|
||||
tornado.web.Application.__init__(self, handlers)
|
||||
|
||||
def log_request(self, handler):
|
||||
pass
|
||||
|
||||
|
||||
# WebhookHandler, process webhook calls
|
||||
# Based on: https://github.com/eternnoir/pyTelegramBotAPI/blob/master/
|
||||
# examples/webhook_examples/webhook_cpython_echo_bot.py
|
||||
class WebhookHandler(BaseHTTPServer.BaseHTTPRequestHandler, object):
|
||||
server_version = 'WebhookHandler/1.0'
|
||||
class WebhookHandler(tornado.web.RequestHandler):
|
||||
SUPPORTED_METHODS = ["POST"]
|
||||
|
||||
def __init__(self, request, client_address, server):
|
||||
def __init__(self, application, request, **kwargs):
|
||||
super(WebhookHandler, self).__init__(application, request, **kwargs)
|
||||
self.logger = logging.getLogger(__name__)
|
||||
super(WebhookHandler, self).__init__(request, client_address, server)
|
||||
|
||||
def do_HEAD(self):
|
||||
self.send_response(200)
|
||||
self.end_headers()
|
||||
def initialize(self, bot, update_queue):
|
||||
self.bot = bot
|
||||
self.update_queue = update_queue
|
||||
|
||||
def do_GET(self):
|
||||
self.send_response(200)
|
||||
self.end_headers()
|
||||
def set_default_headers(self):
|
||||
self.set_header("Content-Type", 'application/json; charset="utf-8"')
|
||||
|
||||
def do_POST(self):
|
||||
def post(self):
|
||||
self.logger.debug('Webhook triggered')
|
||||
try:
|
||||
self._validate_post()
|
||||
clen = self._get_content_len()
|
||||
except _InvalidPost as e:
|
||||
self.send_error(e.http_code)
|
||||
self.end_headers()
|
||||
else:
|
||||
buf = self.rfile.read(clen)
|
||||
json_string = bytes_to_native_str(buf)
|
||||
|
||||
self.send_response(200)
|
||||
self.end_headers()
|
||||
|
||||
self.logger.debug('Webhook received data: ' + json_string)
|
||||
|
||||
update = Update.de_json(json.loads(json_string), self.server.bot)
|
||||
|
||||
self.logger.debug('Received Update with ID %d on Webhook' % update.update_id)
|
||||
self.server.update_queue.put(update)
|
||||
self._validate_post()
|
||||
json_string = bytes_to_native_str(self.request.body)
|
||||
data = json.loads(json_string)
|
||||
self.set_status(200)
|
||||
self.logger.debug('Webhook received data: ' + json_string)
|
||||
update = Update.de_json(data, self.bot)
|
||||
self.logger.debug('Received Update with ID %d on Webhook' % update.update_id)
|
||||
self.update_queue.put(update)
|
||||
|
||||
def _validate_post(self):
|
||||
if not (self.path == self.server.webhook_path and 'content-type' in self.headers and
|
||||
self.headers['content-type'] == 'application/json'):
|
||||
raise _InvalidPost(403)
|
||||
ct_header = self.request.headers.get("Content-Type", None)
|
||||
if ct_header != 'application/json':
|
||||
raise tornado.web.HTTPError(403)
|
||||
|
||||
def _get_content_len(self):
|
||||
clen = self.headers.get('content-length')
|
||||
if clen is None:
|
||||
raise _InvalidPost(411)
|
||||
try:
|
||||
clen = int(clen)
|
||||
except ValueError:
|
||||
raise _InvalidPost(403)
|
||||
if clen < 0:
|
||||
raise _InvalidPost(403)
|
||||
return clen
|
||||
|
||||
def log_message(self, format, *args):
|
||||
def write_error(self, status_code, **kwargs):
|
||||
"""Log an arbitrary message.
|
||||
|
||||
This is used by all other logging functions.
|
||||
|
@ -145,4 +127,6 @@ class WebhookHandler(BaseHTTPServer.BaseHTTPRequestHandler, object):
|
|||
The client ip is prefixed to every message.
|
||||
|
||||
"""
|
||||
self.logger.debug("%s - - %s" % (self.address_string(), format % args))
|
||||
super(WebhookHandler, self).write_error(status_code, **kwargs)
|
||||
self.logger.debug("%s - - %s" % (self.request.remote_ip, "Exception in WebhookHandler"),
|
||||
exc_info=kwargs['exc_info'])
|
||||
|
|
|
@ -150,14 +150,8 @@ class TestUpdater(object):
|
|||
updater.start_webhook(
|
||||
ip,
|
||||
port,
|
||||
url_path='TOKEN',
|
||||
cert='./tests/test_updater.py',
|
||||
key='./tests/test_updater.py', )
|
||||
url_path='TOKEN')
|
||||
sleep(.2)
|
||||
# SSL-Wrapping will fail, so we start the server without SSL
|
||||
thr = Thread(target=updater.httpd.serve_forever)
|
||||
thr.start()
|
||||
|
||||
try:
|
||||
# Now, we send an update to the server via urlopen
|
||||
update = Update(1, message=Message(1, User(1, '', False), None, Chat(1, ''),
|
||||
|
@ -166,21 +160,44 @@ class TestUpdater(object):
|
|||
sleep(.2)
|
||||
assert q.get(False) == update
|
||||
|
||||
response = self._send_webhook_msg(ip, port, None, 'webookhandler.py')
|
||||
assert b'' == response.read()
|
||||
assert 200 == response.code
|
||||
# Returns 404 if path is incorrect
|
||||
with pytest.raises(HTTPError) as excinfo:
|
||||
self._send_webhook_msg(ip, port, None, 'webookhandler.py')
|
||||
assert excinfo.value.code == 404
|
||||
|
||||
response = self._send_webhook_msg(ip, port, None, 'webookhandler.py',
|
||||
get_method=lambda: 'HEAD')
|
||||
|
||||
assert b'' == response.read()
|
||||
assert 200 == response.code
|
||||
with pytest.raises(HTTPError) as excinfo:
|
||||
self._send_webhook_msg(ip, port, None, 'webookhandler.py',
|
||||
get_method=lambda: 'HEAD')
|
||||
assert excinfo.value.code == 404
|
||||
|
||||
# Test multiple shutdown() calls
|
||||
updater.httpd.shutdown()
|
||||
finally:
|
||||
updater.httpd.shutdown()
|
||||
thr.join()
|
||||
sleep(.2)
|
||||
assert not updater.httpd.is_running
|
||||
updater.stop()
|
||||
|
||||
def test_webhook_ssl(self, monkeypatch, updater):
|
||||
monkeypatch.setattr('telegram.Bot.set_webhook', lambda *args, **kwargs: True)
|
||||
monkeypatch.setattr('telegram.Bot.delete_webhook', lambda *args, **kwargs: True)
|
||||
ip = '127.0.0.1'
|
||||
port = randrange(1024, 49152) # Select random port for travis
|
||||
tg_err = False
|
||||
try:
|
||||
updater._start_webhook(
|
||||
ip,
|
||||
port,
|
||||
url_path='TOKEN',
|
||||
cert='./tests/test_updater.py',
|
||||
key='./tests/test_updater.py',
|
||||
bootstrap_retries=0,
|
||||
clean=False,
|
||||
webhook_url=None,
|
||||
allowed_updates=None)
|
||||
except TelegramError:
|
||||
tg_err = True
|
||||
assert tg_err
|
||||
|
||||
def test_webhook_no_ssl(self, monkeypatch, updater):
|
||||
q = Queue()
|
||||
|
@ -199,6 +216,7 @@ class TestUpdater(object):
|
|||
self._send_webhook_msg(ip, port, update.to_json())
|
||||
sleep(.2)
|
||||
assert q.get(False) == update
|
||||
updater.stop()
|
||||
|
||||
@pytest.mark.parametrize(('error',),
|
||||
argvalues=[(TelegramError(''),)],
|
||||
|
@ -254,7 +272,7 @@ class TestUpdater(object):
|
|||
|
||||
with pytest.raises(HTTPError) as excinfo:
|
||||
self._send_webhook_msg(ip, port, 'dummy-payload', content_len=-2)
|
||||
assert excinfo.value.code == 403
|
||||
assert excinfo.value.code == 500
|
||||
|
||||
# TODO: prevent urllib or the underlying from adding content-length
|
||||
# with pytest.raises(HTTPError) as excinfo:
|
||||
|
@ -263,7 +281,7 @@ class TestUpdater(object):
|
|||
|
||||
with pytest.raises(HTTPError):
|
||||
self._send_webhook_msg(ip, port, 'dummy-payload', content_len='not-a-number')
|
||||
assert excinfo.value.code == 403
|
||||
assert excinfo.value.code == 500
|
||||
|
||||
finally:
|
||||
updater.httpd.shutdown()
|
||||
|
|
Loading…
Add table
Reference in a new issue