updater: allow cleaning updates from Telegram servers before start

This commit is contained in:
Noam Meltzer 2016-03-01 21:40:04 +02:00
parent f0e7a3316c
commit a0a040a9c2
2 changed files with 31 additions and 4 deletions

View file

@ -94,7 +94,8 @@ class Updater:
self.__threads = [] self.__threads = []
""":type: list[Thread]""" """:type: list[Thread]"""
def start_polling(self, poll_interval=0.0, timeout=10, network_delay=2): def start_polling(self, poll_interval=0.0, timeout=10, network_delay=2,
clean=False):
""" """
Starts polling updates from Telegram. Starts polling updates from Telegram.
@ -103,14 +104,19 @@ class Updater:
updates from Telegram in seconds. Default is 0.0 updates from Telegram in seconds. Default is 0.0
timeout (Optional[float]): Passed to Bot.getUpdates timeout (Optional[float]): Passed to Bot.getUpdates
network_delay (Optional[float]): Passed to Bot.getUpdates network_delay (Optional[float]): Passed to Bot.getUpdates
clean (Optional[bool]): Whether to clean any pending updates on
Telegram servers before actually starting to poll. Default is
False.
Returns: Returns:
Queue: The update queue that can be filled from the main thread Queue: The update queue that can be filled from the main thread
"""
"""
with self.__lock: with self.__lock:
if not self.running: if not self.running:
self.running = True self.running = True
if clean:
self._clean_updates()
# Create & start threads # Create & start threads
self._init_thread(self.dispatcher.start, "dispatcher") self._init_thread(self.dispatcher.start, "dispatcher")
@ -142,7 +148,8 @@ class Updater:
port=80, port=80,
url_path='', url_path='',
cert=None, cert=None,
key=None): key=None,
clean=False):
""" """
Starts a small http server to listen for updates via webhook. If cert Starts a small http server to listen for updates via webhook. If cert
and key are not provided, the webhook will be started directly on and key are not provided, the webhook will be started directly on
@ -156,6 +163,10 @@ class Updater:
url_path (Optional[str]): Path inside url url_path (Optional[str]): Path inside url
cert (Optional[str]): Path to the SSL certificate file cert (Optional[str]): Path to the SSL certificate file
key (Optional[str]): Path to the SSL key file key (Optional[str]): Path to the SSL key file
clean (Optional[bool]): Whether to clean any pending updates on
Telegram servers before actually starting the webhook. Default
is False.
Returns: Returns:
Queue: The update queue that can be filled from the main thread Queue: The update queue that can be filled from the main thread
@ -164,6 +175,8 @@ class Updater:
with self.__lock: with self.__lock:
if not self.running: if not self.running:
self.running = True self.running = True
if clean:
self._clean_updates()
# Create & start threads # Create & start threads
self._init_thread(self.dispatcher.start, "dispatcher"), self._init_thread(self.dispatcher.start, "dispatcher"),
@ -260,6 +273,12 @@ class Updater:
self.httpd.serve_forever(poll_interval=1) self.httpd.serve_forever(poll_interval=1)
def _clean_updates(self):
self.logger.info('Cleaning updates from Telegram server')
updates = self.bot.getUpdates()
while updates:
updates = self.bot.getUpdates(updates[-1].update_id + 1)
def stop(self): def stop(self):
""" """
Stops the polling/webhook thread, the dispatcher and the job queue Stops the polling/webhook thread, the dispatcher and the job queue

View file

@ -301,6 +301,15 @@ class UpdaterTest(BaseTest, unittest.TestCase):
sleep(.1) sleep(.1)
self.assertEqual(self.received_message, 'Test Error 1') self.assertEqual(self.received_message, 'Test Error 1')
def test_cleanBeforeStart(self):
self._setup_updater('')
d = self.updater.dispatcher
d.addTelegramMessageHandler(self.telegramHandlerTest)
self.updater.start_polling(0.01, clean=True)
sleep(.1)
self.assertEqual(self.message_count, 0)
self.assertIsNone(self.received_message)
def test_errorOnGetUpdates(self): def test_errorOnGetUpdates(self):
self._setup_updater('', raise_error=True) self._setup_updater('', raise_error=True)
d = self.updater.dispatcher d = self.updater.dispatcher
@ -396,7 +405,6 @@ class UpdaterTest(BaseTest, unittest.TestCase):
self.assertEqual(self.received_message, (('This', 'regex group'), self.assertEqual(self.received_message, (('This', 'regex group'),
{'testgroup': 'regex group'})) {'testgroup': 'regex group'}))
def test_runAsyncWithAdditionalArgs(self): def test_runAsyncWithAdditionalArgs(self):
self._setup_updater('Test6', messages=2) self._setup_updater('Test6', messages=2)
d = self.updater.dispatcher d = self.updater.dispatcher