diff --git a/telegram/ext/callbackcontext.py b/telegram/ext/callbackcontext.py index 099f9c15a..952d4499f 100644 --- a/telegram/ext/callbackcontext.py +++ b/telegram/ext/callbackcontext.py @@ -28,6 +28,20 @@ class CallbackContext(object): :attr:`telegram.ext.Dispatcher.add_error_handler` or to the callback of a :class:`telegram.ext.Job`. + Note: + :class:`telegram.ext.Dispatcher` will create a single context for an entire update. This + means that if you got 2 handlers in different groups and they both get called, they will + get passed the same `CallbackContext` object (of course with proper attributes like + `.matches` differing). This allows you to add custom attributes in a lower handler group + callback, and then subsequently access those attributes in a higher handler group callback. + Note that the attributes on `CallbackContext` might change in the future, so make sure to + use a fairly unique name for the attributes. + + Warning: + Do not combine custom attributes and @run_async. Due to how @run_async works, it will + almost certainly execute the callbacks for an update out of order, and the attributes + that you think you added will not be present. + Attributes: chat_data (:obj:`dict`, optional): A dict that can be used to keep any data in. For each update from the same chat it will be the same ``dict``. diff --git a/telegram/ext/conversationhandler.py b/telegram/ext/conversationhandler.py index e843759ad..183470029 100644 --- a/telegram/ext/conversationhandler.py +++ b/telegram/ext/conversationhandler.py @@ -317,7 +317,7 @@ class ConversationHandler(Handler): return key, handler, check - def handle_update(self, update, dispatcher, check_result): + def handle_update(self, update, dispatcher, check_result, context=None): """Send the update to the callback for the current state and Handler Args: @@ -328,7 +328,7 @@ class ConversationHandler(Handler): """ conversation_key, handler, check_result = check_result - new_state = handler.handle_update(update, dispatcher, check_result) + new_state = handler.handle_update(update, dispatcher, check_result, context) timeout_job = self.timeout_jobs.pop(conversation_key, None) if timeout_job is not None: diff --git a/telegram/ext/dispatcher.py b/telegram/ext/dispatcher.py index feda9b9be..9f8f7c652 100644 --- a/telegram/ext/dispatcher.py +++ b/telegram/ext/dispatcher.py @@ -43,14 +43,16 @@ DEFAULT_GROUP = 0 def run_async(func): - """Function decorator that will run the function in a new thread. + """ + Function decorator that will run the function in a new thread. Will run :attr:`telegram.ext.Dispatcher.run_async`. Using this decorator is only possible when only a single Dispatcher exist in the system. - Note: Use this decorator to run handlers asynchronously. - + Warning: + If you're using @run_async you cannot rely on adding custom attributes to + :class:`telegram.ext.CallbackContext`s. See its docs for more info. """ @wraps(func) @@ -210,6 +212,10 @@ class Dispatcher(object): def run_async(self, func, *args, **kwargs): """Queue a function (with given args/kwargs) to be run asynchronously. + Warning: + If you're using @run_async you cannot rely on adding custom attributes to + :class:`telegram.ext.CallbackContext`s. See its docs for more info. + Args: func (:obj:`callable`): The function to run in the thread. *args (:obj:`tuple`, optional): Arguments to `func`. @@ -315,12 +321,16 @@ class Dispatcher(object): self.logger.exception('An uncaught error was raised while handling the error') return + context = None + for group in self.groups: try: for handler in self.handlers[group]: check = handler.check_update(update) if check is not None and check is not False: - handler.handle_update(update, self, check) + if not context and self.use_context: + context = CallbackContext.from_update(update, self) + handler.handle_update(update, self, check, context) if self.persistence and isinstance(update, Update): if self.persistence.store_chat_data and update.effective_chat: chat_id = update.effective_chat.id diff --git a/telegram/ext/handler.py b/telegram/ext/handler.py index caa3f007c..2201943ff 100644 --- a/telegram/ext/handler.py +++ b/telegram/ext/handler.py @@ -17,7 +17,6 @@ # You should have received a copy of the GNU Lesser Public License # along with this program. If not, see [http://www.gnu.org/licenses/]. """This module contains the base class for handlers as used by the Dispatcher.""" -from telegram.ext.callbackcontext import CallbackContext class Handler(object): @@ -99,7 +98,7 @@ class Handler(object): """ raise NotImplementedError - def handle_update(self, update, dispatcher, check_result): + def handle_update(self, update, dispatcher, check_result, context=None): """ This method is called if it was determined that an update should indeed be handled by this instance. Calls :attr:`self.callback` along with its respectful @@ -113,8 +112,7 @@ class Handler(object): check_result: The result from :attr:`check_update`. """ - if dispatcher.use_context: - context = CallbackContext.from_update(update, dispatcher) + if context: self.collect_additional_context(context, update, dispatcher, check_result) return self.callback(update, context) else: diff --git a/tests/test_dispatcher.py b/tests/test_dispatcher.py index e623dcafd..73f1df3db 100644 --- a/tests/test_dispatcher.py +++ b/tests/test_dispatcher.py @@ -80,6 +80,26 @@ class TestDispatcher(object): and isinstance(context.error, TelegramError)): self.received = context.error.message + def test_one_context_per_update(self, cdp): + def one(update, context): + if update.message.text == 'test': + context.my_flag = True + + def two(update, context): + if update.message.text == 'test': + if not hasattr(context, 'my_flag'): + pytest.fail() + else: + if hasattr(context, 'my_flag'): + pytest.fail() + + cdp.add_handler(MessageHandler(Filters.regex('test'), one), group=1) + cdp.add_handler(MessageHandler(None, two), group=2) + u = Update(1, Message(1, None, None, None, text='test')) + cdp.process_update(u) + u.message.text = 'something' + cdp.process_update(u) + def test_error_handler(self, dp): dp.add_error_handler(self.error_handler) error = TelegramError('Unauthorized.')