From 3d8771bbdf1b98b192509cce26101e15e5627e40 Mon Sep 17 00:00:00 2001 From: Lorenzo Rossi Date: Thu, 17 Oct 2019 00:03:53 +0200 Subject: [PATCH] :twisted_rightwards_arrows: Add mutex protection on ConversationHandler (#1533) * Add mutex protection on ConversationHandler * Remove timeout job before child update * Make locks private * Add conversation timeout conflict test --- telegram/ext/conversationhandler.py | 73 +++++++++++++++++++---------- tests/test_conversationhandler.py | 47 +++++++++++++++++++ 2 files changed, 94 insertions(+), 26 deletions(-) diff --git a/telegram/ext/conversationhandler.py b/telegram/ext/conversationhandler.py index a53f4e81c..67da65fe4 100644 --- a/telegram/ext/conversationhandler.py +++ b/telegram/ext/conversationhandler.py @@ -20,6 +20,7 @@ import logging import warnings +from threading import Lock from telegram import Update from telegram.ext import (Handler, CallbackQueryHandler, InlineQueryHandler, @@ -184,7 +185,9 @@ class ConversationHandler(Handler): self.map_to_parent = map_to_parent self.timeout_jobs = dict() + self._timeout_jobs_lock = Lock() self.conversations = dict() + self._conversations_lock = Lock() self.logger = logging.getLogger(__name__) @@ -262,7 +265,8 @@ class ConversationHandler(Handler): return None key = self._get_key(update) - state = self.conversations.get(key) + with self._conversations_lock: + state = self.conversations.get(key) # Resolve promises if isinstance(state, tuple) and len(state) == 2 and isinstance(state[1], Promise): @@ -281,7 +285,8 @@ class ConversationHandler(Handler): if res is None and old_state is None: res = self.END self.update_state(res, key) - state = self.conversations.get(key) + with self._conversations_lock: + state = self.conversations.get(key) else: handlers = self.states.get(self.WAITING, []) for handler in handlers: @@ -340,15 +345,22 @@ class ConversationHandler(Handler): """ conversation_key, handler, check_result = check_result - new_state = handler.handle_update(update, dispatcher, check_result, context) - timeout_job = self.timeout_jobs.pop(conversation_key, None) - if timeout_job is not None: - timeout_job.schedule_removal() - if self.conversation_timeout and new_state != self.END: - self.timeout_jobs[conversation_key] = dispatcher.job_queue.run_once( - self._trigger_timeout, self.conversation_timeout, - context=_ConversationTimeoutContext(conversation_key, update, dispatcher)) + with self._timeout_jobs_lock: + # Remove the old timeout job (if present) + timeout_job = self.timeout_jobs.pop(conversation_key, None) + + if timeout_job is not None: + timeout_job.schedule_removal() + + new_state = handler.handle_update(update, dispatcher, check_result, context) + + with self._timeout_jobs_lock: + if self.conversation_timeout and new_state != self.END: + # Add the new timeout job + self.timeout_jobs[conversation_key] = dispatcher.job_queue.run_once( + self._trigger_timeout, self.conversation_timeout, + context=_ConversationTimeoutContext(conversation_key, update, dispatcher)) if isinstance(self.map_to_parent, dict) and new_state in self.map_to_parent: self.update_state(self.END, conversation_key) @@ -358,33 +370,42 @@ class ConversationHandler(Handler): def update_state(self, new_state, key): if new_state == self.END: - if key in self.conversations: - # If there is no key in conversations, nothing is done. - del self.conversations[key] - if self.persistent: - self.persistence.update_conversation(self.name, key, None) + with self._conversations_lock: + if key in self.conversations: + # If there is no key in conversations, nothing is done. + del self.conversations[key] + if self.persistent: + self.persistence.update_conversation(self.name, key, None) elif isinstance(new_state, Promise): - self.conversations[key] = (self.conversations.get(key), new_state) - if self.persistent: - self.persistence.update_conversation(self.name, key, - (self.conversations.get(key), new_state)) + with self._conversations_lock: + self.conversations[key] = (self.conversations.get(key), new_state) + if self.persistent: + self.persistence.update_conversation(self.name, key, + (self.conversations.get(key), new_state)) elif new_state is not None: - self.conversations[key] = new_state - if self.persistent: - self.persistence.update_conversation(self.name, key, new_state) + with self._conversations_lock: + self.conversations[key] = new_state + if self.persistent: + self.persistence.update_conversation(self.name, key, new_state) def _trigger_timeout(self, context, job=None): self.logger.debug('conversation timeout was triggered!') # Backward compatibility with bots that do not use CallbackContext if isinstance(context, CallbackContext): - context = context.job.context - else: - context = job.context + job = context.job + + context = job.context + + with self._timeout_jobs_lock: + found_job = self.timeout_jobs[context.conversation_key] + if found_job is not job: + # The timeout has been canceled in handle_update + return + del self.timeout_jobs[context.conversation_key] - del self.timeout_jobs[context.conversation_key] handlers = self.states.get(self.TIMEOUT, []) for handler in handlers: check = handler.check_update(context.update) diff --git a/tests/test_conversationhandler.py b/tests/test_conversationhandler.py index 269cb895d..d82d10cbe 100644 --- a/tests/test_conversationhandler.py +++ b/tests/test_conversationhandler.py @@ -613,6 +613,53 @@ class TestConversationHandler(object): assert handler.conversations.get((self.group.id, user1.id)) is None assert not self.is_timeout + def test_conversation_timeout_cancel_conflict(self, dp, bot, user1): + # Start state machine, wait half the timeout, + # then call a callback that takes more than the timeout + # t=0 /start (timeout=.5) + # t=.25 /slowbrew (sleep .5) + # | t=.5 original timeout (should not execute) + # | t=.75 /slowbrew returns (timeout=1.25) + # t=1.25 timeout + + def slowbrew(_bot, update): + sleep(0.25) + # Let's give to the original timeout a chance to execute + dp.job_queue.tick() + sleep(0.25) + # By returning None we do not override the conversation state so + # we can see if the timeout has been executed + + states = self.states + states[self.THIRSTY].append(CommandHandler('slowbrew', slowbrew)) + states.update({ConversationHandler.TIMEOUT: [ + MessageHandler(None, self.passout2) + ]}) + + handler = ConversationHandler(entry_points=self.entry_points, states=states, + fallbacks=self.fallbacks, conversation_timeout=0.5) + dp.add_handler(handler) + + # CommandHandler timeout + message = Message(0, user1, None, self.group, text='/start', + entities=[MessageEntity(type=MessageEntity.BOT_COMMAND, offset=0, + length=len('/start'))], + bot=bot) + dp.process_update(Update(update_id=0, message=message)) + sleep(0.25) + dp.job_queue.tick() + message.text = '/slowbrew' + message.entities[0].length = len('/slowbrew') + dp.process_update(Update(update_id=0, message=message)) + dp.job_queue.tick() + assert handler.conversations.get((self.group.id, user1.id)) is not None + assert not self.is_timeout + + sleep(0.5) + dp.job_queue.tick() + assert handler.conversations.get((self.group.id, user1.id)) is None + assert self.is_timeout + def test_per_message_warning_is_only_shown_once(self, recwarn): ConversationHandler( entry_points=self.entry_points,