diff --git a/telegram/ext/updater.py b/telegram/ext/updater.py index 0b1965ce9..72b3dc30c 100644 --- a/telegram/ext/updater.py +++ b/telegram/ext/updater.py @@ -30,6 +30,7 @@ import subprocess from signal import signal, SIGINT, SIGTERM, SIGABRT from telegram import Bot, TelegramError, NullHandler from telegram.ext import dispatcher, Dispatcher, JobQueue, UpdateQueue +from telegram.error import Unauthorized, InvalidToken from telegram.utils.webhookhandler import (WebhookServer, WebhookHandler) logging.getLogger(__name__).addHandler(NullHandler()) @@ -94,7 +95,7 @@ class Updater: """:type: list[Thread]""" def start_polling(self, poll_interval=0.0, timeout=10, network_delay=2, - clean=False): + clean=False, bootstrap_retries=0): """ Starts polling updates from Telegram. @@ -106,6 +107,11 @@ class Updater: clean (Optional[bool]): Whether to clean any pending updates on Telegram servers before actually starting to poll. Default is False. + bootstrap_retries (Optional[int[): Whether the bootstrapping phase + of the `Updater` will retry on failures on the Telegram server. + < 0 - retry indefinitely + 0 - no retries (default) + > 0 - retry up to X times Returns: Queue: The update queue that can be filled from the main thread @@ -120,7 +126,8 @@ class Updater: # Create & start threads self._init_thread(self.dispatcher.start, "dispatcher") self._init_thread(self._start_polling, "updater", - poll_interval, timeout, network_delay) + poll_interval, timeout, network_delay, + bootstrap_retries) # Return the update queue so the main thread can insert updates return self.update_queue @@ -148,7 +155,9 @@ class Updater: url_path='', cert=None, key=None, - clean=False): + clean=False, + bootstrap_retries=0, + webhook_url=None): """ Starts a small http server to listen for updates via webhook. If cert and key are not provided, the webhook will be started directly on @@ -165,12 +174,19 @@ class Updater: clean (Optional[bool]): Whether to clean any pending updates on Telegram servers before actually starting the webhook. Default is False. - + bootstrap_retries (Optional[int[): Whether the bootstrapping phase + of the `Updater` will retry on failures on the Telegram server. + < 0 - retry indefinitely + 0 - no retries (default) + > 0 - retry up to X times + webhook_url (Optional[str]): Explicitly specifiy the webhook url. + Useful behind NAT, reverse proxy, etc. Default is derived from + `listen`, `port` & `url_path`. Returns: Queue: The update queue that can be filled from the main thread - """ + """ with self.__lock: if not self.running: self.running = True @@ -180,23 +196,24 @@ class Updater: # Create & start threads self._init_thread(self.dispatcher.start, "dispatcher"), self._init_thread(self._start_webhook, "updater", listen, - port, url_path, cert, key) + port, url_path, cert, key, bootstrap_retries, + webhook_url) # Return the update queue so the main thread can insert updates return self.update_queue - def _start_polling(self, poll_interval, timeout, network_delay): + def _start_polling(self, poll_interval, timeout, network_delay, + bootstrap_retries): """ Thread target of thread 'updater'. Runs in background, pulls updates from Telegram and inserts them in the update queue of the Dispatcher. - """ + """ cur_interval = poll_interval self.logger.debug('Updater thread started') - # Remove webhook - self.bot.setWebhook(webhook_url=None) + self._set_webhook(None, bootstrap_retries) while self.running: try: @@ -228,6 +245,27 @@ class Updater: sleep(cur_interval) + def _set_webhook(self, webhook_url, max_retries): + retries = 0 + while 1: + try: + # Remove webhook + self.bot.setWebhook(webhook_url=webhook_url) + except (Unauthorized, InvalidToken): + raise + except TelegramError: + msg = 'failed to set webhook; try={0} max_retries={1}'.format( + retries, max_retries) + if max_retries < 0 or retries < max_retries: + self.logger.info(msg) + retries += 1 + else: + self.logger.exception(msg) + raise + else: + break + sleep(1) + @staticmethod def _increase_poll_interval(current_interval): # increase waiting times on subsequent errors up to 30secs @@ -239,7 +277,8 @@ class Updater: current_interval = 30 return current_interval - def _start_webhook(self, listen, port, url_path, cert, key): + def _start_webhook(self, listen, port, url_path, cert, key, + bootstrap_retries, webhook_url): self.logger.debug('Updater thread started') use_ssl = cert is not None and key is not None url_path = "/%s" % url_path @@ -247,6 +286,10 @@ class Updater: # Create and start server self.httpd = WebhookServer((listen, port), WebhookHandler, self.update_queue, url_path) + if not webhook_url: + webhook_url = self._gen_webhook_url(listen, port, url_path, + use_ssl) + self._set_webhook(webhook_url, bootstrap_retries) if use_ssl: # Check SSL-Certificate with openssl, if possible @@ -272,6 +315,13 @@ class Updater: self.httpd.serve_forever(poll_interval=1) + def _gen_webhook_url(self, listen, port, url_path, use_ssl): + return '{proto}://{listen}:{port}{path}'.format( + proto='https' if use_ssl else 'http', + listen=listen, + port=port, + path=url_path) + def _clean_updates(self): self.logger.debug('Cleaning updates from Telegram server') updates = self.bot.getUpdates() diff --git a/tests/test_updater.py b/tests/test_updater.py index 61022d67e..7883b545e 100644 --- a/tests/test_updater.py +++ b/tests/test_updater.py @@ -47,8 +47,10 @@ except ImportError: sys.path.append('.') -from telegram import Update, Message, TelegramError, User, Chat, Updater, Bot +from telegram import Update, Message, TelegramError, User, Chat, Bot +from telegram.ext.updater import Updater from telegram.ext.dispatcher import run_async +from telegram.error import Unauthorized, InvalidToken from tests.base import BaseTest from threading import Lock, Thread @@ -483,13 +485,46 @@ class UpdaterTest(BaseTest, unittest.TestCase): sleep(1) self.assertEqual(self.received_message, 'Webhook Test 2') + def test_bootstrap_retries_success(self): + retries = 3 + self._setup_updater('', messages=0, bootstrap_retries=retries) + + self.updater._set_webhook('path', retries) + self.assertEqual(self.updater.bot.bootstrap_attempts, retries) + + def test_bootstrap_retries_unauth(self): + retries = 3 + self._setup_updater('', messages=0, bootstrap_retries=retries, + bootstrap_err=Unauthorized()) + + self.assertRaises(Unauthorized, self.updater._set_webhook, 'path', + retries) + self.assertEqual(self.updater.bot.bootstrap_attempts, 1) + + def test_bootstrap_retries_invalid_token(self): + retries = 3 + self._setup_updater('', messages=0, bootstrap_retries=retries, + bootstrap_err=InvalidToken()) + + self.assertRaises(InvalidToken, self.updater._set_webhook, 'path', + retries) + self.assertEqual(self.updater.bot.bootstrap_attempts, 1) + + def test_bootstrap_retries_fail(self): + retries = 1 + self._setup_updater('', messages=0, bootstrap_retries=retries) + + self.assertRaisesRegexp(TelegramError, 'test', + self.updater._set_webhook, 'path', retries - 1) + self.assertEqual(self.updater.bot.bootstrap_attempts, 1) + def test_webhook_invalid_posts(self): self._setup_updater('', messages=0) ip = '127.0.0.1' port = randrange(1024, 49152) # select random port for travis thr = Thread(target=self.updater._start_webhook, - args=(ip, port, '', None, None)) + args=(ip, port, '', None, None, 0, None)) thr.start() sleep(0.5) @@ -578,12 +613,15 @@ class UpdaterTest(BaseTest, unittest.TestCase): class MockBot: - def __init__(self, text, messages=1, raise_error=False): + def __init__(self, text, messages=1, raise_error=False, + bootstrap_retries=None, bootstrap_err=TelegramError('test')): self.text = text self.send_messages = messages self.raise_error = raise_error self.token = "TOKEN" - pass + self.bootstrap_retries = bootstrap_retries + self.bootstrap_attempts = 0 + self.bootstrap_err = bootstrap_err @staticmethod def mockUpdate(text): @@ -594,7 +632,12 @@ class MockBot: return update def setWebhook(self, webhook_url=None, certificate=None): - pass + if self.bootstrap_retries is None: + return + + if self.bootstrap_attempts < self.bootstrap_retries: + self.bootstrap_attempts += 1 + raise self.bootstrap_err def getUpdates(self, offset=None,