🔀 Add mutex protection on ConversationHandler (#1533)

* Add mutex protection on ConversationHandler

* Remove timeout job before child update

* Make locks private

* Add conversation timeout conflict test
This commit is contained in:
Lorenzo Rossi 2019-10-17 00:03:53 +02:00 committed by Jannes Höke
parent 7152b5aaf9
commit 3d8771bbdf
2 changed files with 94 additions and 26 deletions

View file

@ -20,6 +20,7 @@
import logging
import warnings
from threading import Lock
from telegram import Update
from telegram.ext import (Handler, CallbackQueryHandler, InlineQueryHandler,
@ -184,7 +185,9 @@ class ConversationHandler(Handler):
self.map_to_parent = map_to_parent
self.timeout_jobs = dict()
self._timeout_jobs_lock = Lock()
self.conversations = dict()
self._conversations_lock = Lock()
self.logger = logging.getLogger(__name__)
@ -262,7 +265,8 @@ class ConversationHandler(Handler):
return None
key = self._get_key(update)
state = self.conversations.get(key)
with self._conversations_lock:
state = self.conversations.get(key)
# Resolve promises
if isinstance(state, tuple) and len(state) == 2 and isinstance(state[1], Promise):
@ -281,7 +285,8 @@ class ConversationHandler(Handler):
if res is None and old_state is None:
res = self.END
self.update_state(res, key)
state = self.conversations.get(key)
with self._conversations_lock:
state = self.conversations.get(key)
else:
handlers = self.states.get(self.WAITING, [])
for handler in handlers:
@ -340,15 +345,22 @@ class ConversationHandler(Handler):
"""
conversation_key, handler, check_result = check_result
new_state = handler.handle_update(update, dispatcher, check_result, context)
timeout_job = self.timeout_jobs.pop(conversation_key, None)
if timeout_job is not None:
timeout_job.schedule_removal()
if self.conversation_timeout and new_state != self.END:
self.timeout_jobs[conversation_key] = dispatcher.job_queue.run_once(
self._trigger_timeout, self.conversation_timeout,
context=_ConversationTimeoutContext(conversation_key, update, dispatcher))
with self._timeout_jobs_lock:
# Remove the old timeout job (if present)
timeout_job = self.timeout_jobs.pop(conversation_key, None)
if timeout_job is not None:
timeout_job.schedule_removal()
new_state = handler.handle_update(update, dispatcher, check_result, context)
with self._timeout_jobs_lock:
if self.conversation_timeout and new_state != self.END:
# Add the new timeout job
self.timeout_jobs[conversation_key] = dispatcher.job_queue.run_once(
self._trigger_timeout, self.conversation_timeout,
context=_ConversationTimeoutContext(conversation_key, update, dispatcher))
if isinstance(self.map_to_parent, dict) and new_state in self.map_to_parent:
self.update_state(self.END, conversation_key)
@ -358,33 +370,42 @@ class ConversationHandler(Handler):
def update_state(self, new_state, key):
if new_state == self.END:
if key in self.conversations:
# If there is no key in conversations, nothing is done.
del self.conversations[key]
if self.persistent:
self.persistence.update_conversation(self.name, key, None)
with self._conversations_lock:
if key in self.conversations:
# If there is no key in conversations, nothing is done.
del self.conversations[key]
if self.persistent:
self.persistence.update_conversation(self.name, key, None)
elif isinstance(new_state, Promise):
self.conversations[key] = (self.conversations.get(key), new_state)
if self.persistent:
self.persistence.update_conversation(self.name, key,
(self.conversations.get(key), new_state))
with self._conversations_lock:
self.conversations[key] = (self.conversations.get(key), new_state)
if self.persistent:
self.persistence.update_conversation(self.name, key,
(self.conversations.get(key), new_state))
elif new_state is not None:
self.conversations[key] = new_state
if self.persistent:
self.persistence.update_conversation(self.name, key, new_state)
with self._conversations_lock:
self.conversations[key] = new_state
if self.persistent:
self.persistence.update_conversation(self.name, key, new_state)
def _trigger_timeout(self, context, job=None):
self.logger.debug('conversation timeout was triggered!')
# Backward compatibility with bots that do not use CallbackContext
if isinstance(context, CallbackContext):
context = context.job.context
else:
context = job.context
job = context.job
context = job.context
with self._timeout_jobs_lock:
found_job = self.timeout_jobs[context.conversation_key]
if found_job is not job:
# The timeout has been canceled in handle_update
return
del self.timeout_jobs[context.conversation_key]
del self.timeout_jobs[context.conversation_key]
handlers = self.states.get(self.TIMEOUT, [])
for handler in handlers:
check = handler.check_update(context.update)

View file

@ -613,6 +613,53 @@ class TestConversationHandler(object):
assert handler.conversations.get((self.group.id, user1.id)) is None
assert not self.is_timeout
def test_conversation_timeout_cancel_conflict(self, dp, bot, user1):
# Start state machine, wait half the timeout,
# then call a callback that takes more than the timeout
# t=0 /start (timeout=.5)
# t=.25 /slowbrew (sleep .5)
# | t=.5 original timeout (should not execute)
# | t=.75 /slowbrew returns (timeout=1.25)
# t=1.25 timeout
def slowbrew(_bot, update):
sleep(0.25)
# Let's give to the original timeout a chance to execute
dp.job_queue.tick()
sleep(0.25)
# By returning None we do not override the conversation state so
# we can see if the timeout has been executed
states = self.states
states[self.THIRSTY].append(CommandHandler('slowbrew', slowbrew))
states.update({ConversationHandler.TIMEOUT: [
MessageHandler(None, self.passout2)
]})
handler = ConversationHandler(entry_points=self.entry_points, states=states,
fallbacks=self.fallbacks, conversation_timeout=0.5)
dp.add_handler(handler)
# CommandHandler timeout
message = Message(0, user1, None, self.group, text='/start',
entities=[MessageEntity(type=MessageEntity.BOT_COMMAND, offset=0,
length=len('/start'))],
bot=bot)
dp.process_update(Update(update_id=0, message=message))
sleep(0.25)
dp.job_queue.tick()
message.text = '/slowbrew'
message.entities[0].length = len('/slowbrew')
dp.process_update(Update(update_id=0, message=message))
dp.job_queue.tick()
assert handler.conversations.get((self.group.id, user1.id)) is not None
assert not self.is_timeout
sleep(0.5)
dp.job_queue.tick()
assert handler.conversations.get((self.group.id, user1.id)) is None
assert self.is_timeout
def test_per_message_warning_is_only_shown_once(self, recwarn):
ConversationHandler(
entry_points=self.entry_points,