mirror of
https://github.com/python-telegram-bot/python-telegram-bot.git
synced 2024-11-25 08:37:07 +01:00
🔀 Add mutex protection on ConversationHandler (#1533)
* Add mutex protection on ConversationHandler * Remove timeout job before child update * Make locks private * Add conversation timeout conflict test
This commit is contained in:
parent
7152b5aaf9
commit
3d8771bbdf
2 changed files with 94 additions and 26 deletions
|
@ -20,6 +20,7 @@
|
|||
|
||||
import logging
|
||||
import warnings
|
||||
from threading import Lock
|
||||
|
||||
from telegram import Update
|
||||
from telegram.ext import (Handler, CallbackQueryHandler, InlineQueryHandler,
|
||||
|
@ -184,7 +185,9 @@ class ConversationHandler(Handler):
|
|||
self.map_to_parent = map_to_parent
|
||||
|
||||
self.timeout_jobs = dict()
|
||||
self._timeout_jobs_lock = Lock()
|
||||
self.conversations = dict()
|
||||
self._conversations_lock = Lock()
|
||||
|
||||
self.logger = logging.getLogger(__name__)
|
||||
|
||||
|
@ -262,7 +265,8 @@ class ConversationHandler(Handler):
|
|||
return None
|
||||
|
||||
key = self._get_key(update)
|
||||
state = self.conversations.get(key)
|
||||
with self._conversations_lock:
|
||||
state = self.conversations.get(key)
|
||||
|
||||
# Resolve promises
|
||||
if isinstance(state, tuple) and len(state) == 2 and isinstance(state[1], Promise):
|
||||
|
@ -281,7 +285,8 @@ class ConversationHandler(Handler):
|
|||
if res is None and old_state is None:
|
||||
res = self.END
|
||||
self.update_state(res, key)
|
||||
state = self.conversations.get(key)
|
||||
with self._conversations_lock:
|
||||
state = self.conversations.get(key)
|
||||
else:
|
||||
handlers = self.states.get(self.WAITING, [])
|
||||
for handler in handlers:
|
||||
|
@ -340,15 +345,22 @@ class ConversationHandler(Handler):
|
|||
|
||||
"""
|
||||
conversation_key, handler, check_result = check_result
|
||||
new_state = handler.handle_update(update, dispatcher, check_result, context)
|
||||
timeout_job = self.timeout_jobs.pop(conversation_key, None)
|
||||
|
||||
if timeout_job is not None:
|
||||
timeout_job.schedule_removal()
|
||||
if self.conversation_timeout and new_state != self.END:
|
||||
self.timeout_jobs[conversation_key] = dispatcher.job_queue.run_once(
|
||||
self._trigger_timeout, self.conversation_timeout,
|
||||
context=_ConversationTimeoutContext(conversation_key, update, dispatcher))
|
||||
with self._timeout_jobs_lock:
|
||||
# Remove the old timeout job (if present)
|
||||
timeout_job = self.timeout_jobs.pop(conversation_key, None)
|
||||
|
||||
if timeout_job is not None:
|
||||
timeout_job.schedule_removal()
|
||||
|
||||
new_state = handler.handle_update(update, dispatcher, check_result, context)
|
||||
|
||||
with self._timeout_jobs_lock:
|
||||
if self.conversation_timeout and new_state != self.END:
|
||||
# Add the new timeout job
|
||||
self.timeout_jobs[conversation_key] = dispatcher.job_queue.run_once(
|
||||
self._trigger_timeout, self.conversation_timeout,
|
||||
context=_ConversationTimeoutContext(conversation_key, update, dispatcher))
|
||||
|
||||
if isinstance(self.map_to_parent, dict) and new_state in self.map_to_parent:
|
||||
self.update_state(self.END, conversation_key)
|
||||
|
@ -358,33 +370,42 @@ class ConversationHandler(Handler):
|
|||
|
||||
def update_state(self, new_state, key):
|
||||
if new_state == self.END:
|
||||
if key in self.conversations:
|
||||
# If there is no key in conversations, nothing is done.
|
||||
del self.conversations[key]
|
||||
if self.persistent:
|
||||
self.persistence.update_conversation(self.name, key, None)
|
||||
with self._conversations_lock:
|
||||
if key in self.conversations:
|
||||
# If there is no key in conversations, nothing is done.
|
||||
del self.conversations[key]
|
||||
if self.persistent:
|
||||
self.persistence.update_conversation(self.name, key, None)
|
||||
|
||||
elif isinstance(new_state, Promise):
|
||||
self.conversations[key] = (self.conversations.get(key), new_state)
|
||||
if self.persistent:
|
||||
self.persistence.update_conversation(self.name, key,
|
||||
(self.conversations.get(key), new_state))
|
||||
with self._conversations_lock:
|
||||
self.conversations[key] = (self.conversations.get(key), new_state)
|
||||
if self.persistent:
|
||||
self.persistence.update_conversation(self.name, key,
|
||||
(self.conversations.get(key), new_state))
|
||||
|
||||
elif new_state is not None:
|
||||
self.conversations[key] = new_state
|
||||
if self.persistent:
|
||||
self.persistence.update_conversation(self.name, key, new_state)
|
||||
with self._conversations_lock:
|
||||
self.conversations[key] = new_state
|
||||
if self.persistent:
|
||||
self.persistence.update_conversation(self.name, key, new_state)
|
||||
|
||||
def _trigger_timeout(self, context, job=None):
|
||||
self.logger.debug('conversation timeout was triggered!')
|
||||
|
||||
# Backward compatibility with bots that do not use CallbackContext
|
||||
if isinstance(context, CallbackContext):
|
||||
context = context.job.context
|
||||
else:
|
||||
context = job.context
|
||||
job = context.job
|
||||
|
||||
context = job.context
|
||||
|
||||
with self._timeout_jobs_lock:
|
||||
found_job = self.timeout_jobs[context.conversation_key]
|
||||
if found_job is not job:
|
||||
# The timeout has been canceled in handle_update
|
||||
return
|
||||
del self.timeout_jobs[context.conversation_key]
|
||||
|
||||
del self.timeout_jobs[context.conversation_key]
|
||||
handlers = self.states.get(self.TIMEOUT, [])
|
||||
for handler in handlers:
|
||||
check = handler.check_update(context.update)
|
||||
|
|
|
@ -613,6 +613,53 @@ class TestConversationHandler(object):
|
|||
assert handler.conversations.get((self.group.id, user1.id)) is None
|
||||
assert not self.is_timeout
|
||||
|
||||
def test_conversation_timeout_cancel_conflict(self, dp, bot, user1):
|
||||
# Start state machine, wait half the timeout,
|
||||
# then call a callback that takes more than the timeout
|
||||
# t=0 /start (timeout=.5)
|
||||
# t=.25 /slowbrew (sleep .5)
|
||||
# | t=.5 original timeout (should not execute)
|
||||
# | t=.75 /slowbrew returns (timeout=1.25)
|
||||
# t=1.25 timeout
|
||||
|
||||
def slowbrew(_bot, update):
|
||||
sleep(0.25)
|
||||
# Let's give to the original timeout a chance to execute
|
||||
dp.job_queue.tick()
|
||||
sleep(0.25)
|
||||
# By returning None we do not override the conversation state so
|
||||
# we can see if the timeout has been executed
|
||||
|
||||
states = self.states
|
||||
states[self.THIRSTY].append(CommandHandler('slowbrew', slowbrew))
|
||||
states.update({ConversationHandler.TIMEOUT: [
|
||||
MessageHandler(None, self.passout2)
|
||||
]})
|
||||
|
||||
handler = ConversationHandler(entry_points=self.entry_points, states=states,
|
||||
fallbacks=self.fallbacks, conversation_timeout=0.5)
|
||||
dp.add_handler(handler)
|
||||
|
||||
# CommandHandler timeout
|
||||
message = Message(0, user1, None, self.group, text='/start',
|
||||
entities=[MessageEntity(type=MessageEntity.BOT_COMMAND, offset=0,
|
||||
length=len('/start'))],
|
||||
bot=bot)
|
||||
dp.process_update(Update(update_id=0, message=message))
|
||||
sleep(0.25)
|
||||
dp.job_queue.tick()
|
||||
message.text = '/slowbrew'
|
||||
message.entities[0].length = len('/slowbrew')
|
||||
dp.process_update(Update(update_id=0, message=message))
|
||||
dp.job_queue.tick()
|
||||
assert handler.conversations.get((self.group.id, user1.id)) is not None
|
||||
assert not self.is_timeout
|
||||
|
||||
sleep(0.5)
|
||||
dp.job_queue.tick()
|
||||
assert handler.conversations.get((self.group.id, user1.id)) is None
|
||||
assert self.is_timeout
|
||||
|
||||
def test_per_message_warning_is_only_shown_once(self, recwarn):
|
||||
ConversationHandler(
|
||||
entry_points=self.entry_points,
|
||||
|
|
Loading…
Reference in a new issue