mirror of
https://github.com/python-telegram-bot/python-telegram-bot.git
synced 2024-11-25 16:46:35 +01:00
🔀 Add mutex protection on ConversationHandler (#1533)
* Add mutex protection on ConversationHandler * Remove timeout job before child update * Make locks private * Add conversation timeout conflict test
This commit is contained in:
parent
7152b5aaf9
commit
3d8771bbdf
2 changed files with 94 additions and 26 deletions
|
@ -20,6 +20,7 @@
|
||||||
|
|
||||||
import logging
|
import logging
|
||||||
import warnings
|
import warnings
|
||||||
|
from threading import Lock
|
||||||
|
|
||||||
from telegram import Update
|
from telegram import Update
|
||||||
from telegram.ext import (Handler, CallbackQueryHandler, InlineQueryHandler,
|
from telegram.ext import (Handler, CallbackQueryHandler, InlineQueryHandler,
|
||||||
|
@ -184,7 +185,9 @@ class ConversationHandler(Handler):
|
||||||
self.map_to_parent = map_to_parent
|
self.map_to_parent = map_to_parent
|
||||||
|
|
||||||
self.timeout_jobs = dict()
|
self.timeout_jobs = dict()
|
||||||
|
self._timeout_jobs_lock = Lock()
|
||||||
self.conversations = dict()
|
self.conversations = dict()
|
||||||
|
self._conversations_lock = Lock()
|
||||||
|
|
||||||
self.logger = logging.getLogger(__name__)
|
self.logger = logging.getLogger(__name__)
|
||||||
|
|
||||||
|
@ -262,6 +265,7 @@ class ConversationHandler(Handler):
|
||||||
return None
|
return None
|
||||||
|
|
||||||
key = self._get_key(update)
|
key = self._get_key(update)
|
||||||
|
with self._conversations_lock:
|
||||||
state = self.conversations.get(key)
|
state = self.conversations.get(key)
|
||||||
|
|
||||||
# Resolve promises
|
# Resolve promises
|
||||||
|
@ -281,6 +285,7 @@ class ConversationHandler(Handler):
|
||||||
if res is None and old_state is None:
|
if res is None and old_state is None:
|
||||||
res = self.END
|
res = self.END
|
||||||
self.update_state(res, key)
|
self.update_state(res, key)
|
||||||
|
with self._conversations_lock:
|
||||||
state = self.conversations.get(key)
|
state = self.conversations.get(key)
|
||||||
else:
|
else:
|
||||||
handlers = self.states.get(self.WAITING, [])
|
handlers = self.states.get(self.WAITING, [])
|
||||||
|
@ -340,12 +345,19 @@ class ConversationHandler(Handler):
|
||||||
|
|
||||||
"""
|
"""
|
||||||
conversation_key, handler, check_result = check_result
|
conversation_key, handler, check_result = check_result
|
||||||
new_state = handler.handle_update(update, dispatcher, check_result, context)
|
|
||||||
|
with self._timeout_jobs_lock:
|
||||||
|
# Remove the old timeout job (if present)
|
||||||
timeout_job = self.timeout_jobs.pop(conversation_key, None)
|
timeout_job = self.timeout_jobs.pop(conversation_key, None)
|
||||||
|
|
||||||
if timeout_job is not None:
|
if timeout_job is not None:
|
||||||
timeout_job.schedule_removal()
|
timeout_job.schedule_removal()
|
||||||
|
|
||||||
|
new_state = handler.handle_update(update, dispatcher, check_result, context)
|
||||||
|
|
||||||
|
with self._timeout_jobs_lock:
|
||||||
if self.conversation_timeout and new_state != self.END:
|
if self.conversation_timeout and new_state != self.END:
|
||||||
|
# Add the new timeout job
|
||||||
self.timeout_jobs[conversation_key] = dispatcher.job_queue.run_once(
|
self.timeout_jobs[conversation_key] = dispatcher.job_queue.run_once(
|
||||||
self._trigger_timeout, self.conversation_timeout,
|
self._trigger_timeout, self.conversation_timeout,
|
||||||
context=_ConversationTimeoutContext(conversation_key, update, dispatcher))
|
context=_ConversationTimeoutContext(conversation_key, update, dispatcher))
|
||||||
|
@ -358,6 +370,7 @@ class ConversationHandler(Handler):
|
||||||
|
|
||||||
def update_state(self, new_state, key):
|
def update_state(self, new_state, key):
|
||||||
if new_state == self.END:
|
if new_state == self.END:
|
||||||
|
with self._conversations_lock:
|
||||||
if key in self.conversations:
|
if key in self.conversations:
|
||||||
# If there is no key in conversations, nothing is done.
|
# If there is no key in conversations, nothing is done.
|
||||||
del self.conversations[key]
|
del self.conversations[key]
|
||||||
|
@ -365,12 +378,14 @@ class ConversationHandler(Handler):
|
||||||
self.persistence.update_conversation(self.name, key, None)
|
self.persistence.update_conversation(self.name, key, None)
|
||||||
|
|
||||||
elif isinstance(new_state, Promise):
|
elif isinstance(new_state, Promise):
|
||||||
|
with self._conversations_lock:
|
||||||
self.conversations[key] = (self.conversations.get(key), new_state)
|
self.conversations[key] = (self.conversations.get(key), new_state)
|
||||||
if self.persistent:
|
if self.persistent:
|
||||||
self.persistence.update_conversation(self.name, key,
|
self.persistence.update_conversation(self.name, key,
|
||||||
(self.conversations.get(key), new_state))
|
(self.conversations.get(key), new_state))
|
||||||
|
|
||||||
elif new_state is not None:
|
elif new_state is not None:
|
||||||
|
with self._conversations_lock:
|
||||||
self.conversations[key] = new_state
|
self.conversations[key] = new_state
|
||||||
if self.persistent:
|
if self.persistent:
|
||||||
self.persistence.update_conversation(self.name, key, new_state)
|
self.persistence.update_conversation(self.name, key, new_state)
|
||||||
|
@ -380,11 +395,17 @@ class ConversationHandler(Handler):
|
||||||
|
|
||||||
# Backward compatibility with bots that do not use CallbackContext
|
# Backward compatibility with bots that do not use CallbackContext
|
||||||
if isinstance(context, CallbackContext):
|
if isinstance(context, CallbackContext):
|
||||||
context = context.job.context
|
job = context.job
|
||||||
else:
|
|
||||||
context = job.context
|
context = job.context
|
||||||
|
|
||||||
|
with self._timeout_jobs_lock:
|
||||||
|
found_job = self.timeout_jobs[context.conversation_key]
|
||||||
|
if found_job is not job:
|
||||||
|
# The timeout has been canceled in handle_update
|
||||||
|
return
|
||||||
del self.timeout_jobs[context.conversation_key]
|
del self.timeout_jobs[context.conversation_key]
|
||||||
|
|
||||||
handlers = self.states.get(self.TIMEOUT, [])
|
handlers = self.states.get(self.TIMEOUT, [])
|
||||||
for handler in handlers:
|
for handler in handlers:
|
||||||
check = handler.check_update(context.update)
|
check = handler.check_update(context.update)
|
||||||
|
|
|
@ -613,6 +613,53 @@ class TestConversationHandler(object):
|
||||||
assert handler.conversations.get((self.group.id, user1.id)) is None
|
assert handler.conversations.get((self.group.id, user1.id)) is None
|
||||||
assert not self.is_timeout
|
assert not self.is_timeout
|
||||||
|
|
||||||
|
def test_conversation_timeout_cancel_conflict(self, dp, bot, user1):
|
||||||
|
# Start state machine, wait half the timeout,
|
||||||
|
# then call a callback that takes more than the timeout
|
||||||
|
# t=0 /start (timeout=.5)
|
||||||
|
# t=.25 /slowbrew (sleep .5)
|
||||||
|
# | t=.5 original timeout (should not execute)
|
||||||
|
# | t=.75 /slowbrew returns (timeout=1.25)
|
||||||
|
# t=1.25 timeout
|
||||||
|
|
||||||
|
def slowbrew(_bot, update):
|
||||||
|
sleep(0.25)
|
||||||
|
# Let's give to the original timeout a chance to execute
|
||||||
|
dp.job_queue.tick()
|
||||||
|
sleep(0.25)
|
||||||
|
# By returning None we do not override the conversation state so
|
||||||
|
# we can see if the timeout has been executed
|
||||||
|
|
||||||
|
states = self.states
|
||||||
|
states[self.THIRSTY].append(CommandHandler('slowbrew', slowbrew))
|
||||||
|
states.update({ConversationHandler.TIMEOUT: [
|
||||||
|
MessageHandler(None, self.passout2)
|
||||||
|
]})
|
||||||
|
|
||||||
|
handler = ConversationHandler(entry_points=self.entry_points, states=states,
|
||||||
|
fallbacks=self.fallbacks, conversation_timeout=0.5)
|
||||||
|
dp.add_handler(handler)
|
||||||
|
|
||||||
|
# CommandHandler timeout
|
||||||
|
message = Message(0, user1, None, self.group, text='/start',
|
||||||
|
entities=[MessageEntity(type=MessageEntity.BOT_COMMAND, offset=0,
|
||||||
|
length=len('/start'))],
|
||||||
|
bot=bot)
|
||||||
|
dp.process_update(Update(update_id=0, message=message))
|
||||||
|
sleep(0.25)
|
||||||
|
dp.job_queue.tick()
|
||||||
|
message.text = '/slowbrew'
|
||||||
|
message.entities[0].length = len('/slowbrew')
|
||||||
|
dp.process_update(Update(update_id=0, message=message))
|
||||||
|
dp.job_queue.tick()
|
||||||
|
assert handler.conversations.get((self.group.id, user1.id)) is not None
|
||||||
|
assert not self.is_timeout
|
||||||
|
|
||||||
|
sleep(0.5)
|
||||||
|
dp.job_queue.tick()
|
||||||
|
assert handler.conversations.get((self.group.id, user1.id)) is None
|
||||||
|
assert self.is_timeout
|
||||||
|
|
||||||
def test_per_message_warning_is_only_shown_once(self, recwarn):
|
def test_per_message_warning_is_only_shown_once(self, recwarn):
|
||||||
ConversationHandler(
|
ConversationHandler(
|
||||||
entry_points=self.entry_points,
|
entry_points=self.entry_points,
|
||||||
|
|
Loading…
Reference in a new issue