🔀 Add mutex protection on ConversationHandler (#1533)

* Add mutex protection on ConversationHandler

* Remove timeout job before child update

* Make locks private

* Add conversation timeout conflict test
This commit is contained in:
Lorenzo Rossi 2019-10-17 00:03:53 +02:00 committed by Jannes Höke
parent 7152b5aaf9
commit 3d8771bbdf
2 changed files with 94 additions and 26 deletions

View file

@ -20,6 +20,7 @@
import logging import logging
import warnings import warnings
from threading import Lock
from telegram import Update from telegram import Update
from telegram.ext import (Handler, CallbackQueryHandler, InlineQueryHandler, from telegram.ext import (Handler, CallbackQueryHandler, InlineQueryHandler,
@ -184,7 +185,9 @@ class ConversationHandler(Handler):
self.map_to_parent = map_to_parent self.map_to_parent = map_to_parent
self.timeout_jobs = dict() self.timeout_jobs = dict()
self._timeout_jobs_lock = Lock()
self.conversations = dict() self.conversations = dict()
self._conversations_lock = Lock()
self.logger = logging.getLogger(__name__) self.logger = logging.getLogger(__name__)
@ -262,6 +265,7 @@ class ConversationHandler(Handler):
return None return None
key = self._get_key(update) key = self._get_key(update)
with self._conversations_lock:
state = self.conversations.get(key) state = self.conversations.get(key)
# Resolve promises # Resolve promises
@ -281,6 +285,7 @@ class ConversationHandler(Handler):
if res is None and old_state is None: if res is None and old_state is None:
res = self.END res = self.END
self.update_state(res, key) self.update_state(res, key)
with self._conversations_lock:
state = self.conversations.get(key) state = self.conversations.get(key)
else: else:
handlers = self.states.get(self.WAITING, []) handlers = self.states.get(self.WAITING, [])
@ -340,12 +345,19 @@ class ConversationHandler(Handler):
""" """
conversation_key, handler, check_result = check_result conversation_key, handler, check_result = check_result
new_state = handler.handle_update(update, dispatcher, check_result, context)
with self._timeout_jobs_lock:
# Remove the old timeout job (if present)
timeout_job = self.timeout_jobs.pop(conversation_key, None) timeout_job = self.timeout_jobs.pop(conversation_key, None)
if timeout_job is not None: if timeout_job is not None:
timeout_job.schedule_removal() timeout_job.schedule_removal()
new_state = handler.handle_update(update, dispatcher, check_result, context)
with self._timeout_jobs_lock:
if self.conversation_timeout and new_state != self.END: if self.conversation_timeout and new_state != self.END:
# Add the new timeout job
self.timeout_jobs[conversation_key] = dispatcher.job_queue.run_once( self.timeout_jobs[conversation_key] = dispatcher.job_queue.run_once(
self._trigger_timeout, self.conversation_timeout, self._trigger_timeout, self.conversation_timeout,
context=_ConversationTimeoutContext(conversation_key, update, dispatcher)) context=_ConversationTimeoutContext(conversation_key, update, dispatcher))
@ -358,6 +370,7 @@ class ConversationHandler(Handler):
def update_state(self, new_state, key): def update_state(self, new_state, key):
if new_state == self.END: if new_state == self.END:
with self._conversations_lock:
if key in self.conversations: if key in self.conversations:
# If there is no key in conversations, nothing is done. # If there is no key in conversations, nothing is done.
del self.conversations[key] del self.conversations[key]
@ -365,12 +378,14 @@ class ConversationHandler(Handler):
self.persistence.update_conversation(self.name, key, None) self.persistence.update_conversation(self.name, key, None)
elif isinstance(new_state, Promise): elif isinstance(new_state, Promise):
with self._conversations_lock:
self.conversations[key] = (self.conversations.get(key), new_state) self.conversations[key] = (self.conversations.get(key), new_state)
if self.persistent: if self.persistent:
self.persistence.update_conversation(self.name, key, self.persistence.update_conversation(self.name, key,
(self.conversations.get(key), new_state)) (self.conversations.get(key), new_state))
elif new_state is not None: elif new_state is not None:
with self._conversations_lock:
self.conversations[key] = new_state self.conversations[key] = new_state
if self.persistent: if self.persistent:
self.persistence.update_conversation(self.name, key, new_state) self.persistence.update_conversation(self.name, key, new_state)
@ -380,11 +395,17 @@ class ConversationHandler(Handler):
# Backward compatibility with bots that do not use CallbackContext # Backward compatibility with bots that do not use CallbackContext
if isinstance(context, CallbackContext): if isinstance(context, CallbackContext):
context = context.job.context job = context.job
else:
context = job.context context = job.context
with self._timeout_jobs_lock:
found_job = self.timeout_jobs[context.conversation_key]
if found_job is not job:
# The timeout has been canceled in handle_update
return
del self.timeout_jobs[context.conversation_key] del self.timeout_jobs[context.conversation_key]
handlers = self.states.get(self.TIMEOUT, []) handlers = self.states.get(self.TIMEOUT, [])
for handler in handlers: for handler in handlers:
check = handler.check_update(context.update) check = handler.check_update(context.update)

View file

@ -613,6 +613,53 @@ class TestConversationHandler(object):
assert handler.conversations.get((self.group.id, user1.id)) is None assert handler.conversations.get((self.group.id, user1.id)) is None
assert not self.is_timeout assert not self.is_timeout
def test_conversation_timeout_cancel_conflict(self, dp, bot, user1):
# Start state machine, wait half the timeout,
# then call a callback that takes more than the timeout
# t=0 /start (timeout=.5)
# t=.25 /slowbrew (sleep .5)
# | t=.5 original timeout (should not execute)
# | t=.75 /slowbrew returns (timeout=1.25)
# t=1.25 timeout
def slowbrew(_bot, update):
sleep(0.25)
# Let's give to the original timeout a chance to execute
dp.job_queue.tick()
sleep(0.25)
# By returning None we do not override the conversation state so
# we can see if the timeout has been executed
states = self.states
states[self.THIRSTY].append(CommandHandler('slowbrew', slowbrew))
states.update({ConversationHandler.TIMEOUT: [
MessageHandler(None, self.passout2)
]})
handler = ConversationHandler(entry_points=self.entry_points, states=states,
fallbacks=self.fallbacks, conversation_timeout=0.5)
dp.add_handler(handler)
# CommandHandler timeout
message = Message(0, user1, None, self.group, text='/start',
entities=[MessageEntity(type=MessageEntity.BOT_COMMAND, offset=0,
length=len('/start'))],
bot=bot)
dp.process_update(Update(update_id=0, message=message))
sleep(0.25)
dp.job_queue.tick()
message.text = '/slowbrew'
message.entities[0].length = len('/slowbrew')
dp.process_update(Update(update_id=0, message=message))
dp.job_queue.tick()
assert handler.conversations.get((self.group.id, user1.id)) is not None
assert not self.is_timeout
sleep(0.5)
dp.job_queue.tick()
assert handler.conversations.get((self.group.id, user1.id)) is None
assert self.is_timeout
def test_per_message_warning_is_only_shown_once(self, recwarn): def test_per_message_warning_is_only_shown_once(self, recwarn):
ConversationHandler( ConversationHandler(
entry_points=self.entry_points, entry_points=self.entry_points,