#!/usr/bin/env python
#
# A library that provides a Python interface to the Telegram Bot API
# Copyright (C) 2015-2017
# Leandro Toledo de Souza
#
# This program is free software: you can redistribute it and/or modify
# it under the terms of the GNU Lesser Public License as published by
# the Free Software Foundation, either version 3 of the License, or
# (at your option) any later version.
#
# This program is distributed in the hope that it will be useful,
# but WITHOUT ANY WARRANTY; without even the implied warranty of
# MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the
# GNU Lesser Public License for more details.
#
# You should have received a copy of the GNU Lesser Public License
# along with this program. If not, see [http://www.gnu.org/licenses/].
from queue import Queue
from threading import current_thread
from time import sleep

import pytest

from telegram import TelegramError, Message, User, Chat, Update
from telegram.ext import MessageHandler, Filters, CommandHandler
from telegram.ext.dispatcher import run_async, Dispatcher, DispatcherHandlerContinue, \
    DispatcherHandlerStop
from tests.conftest import create_dp


@pytest.fixture(scope='function')
def dp2(bot):
    """Yield a second Dispatcher, built by ``create_dp``, for tests that need two."""
    for dp in create_dp(bot):
        yield dp


class TestDispatcher(object):
    """Tests for :class:`telegram.ext.Dispatcher`.

    Many tests put an update on ``dp.update_queue`` and then sleep briefly before asserting,
    presumably because the ``dp`` fixture's dispatcher consumes the queue asynchronously
    (it reports ``dp.running``) -- TODO confirm against the fixture in tests/conftest.py.
    """

    message_update = Update(1, message=Message(1, User(1, ''), None, Chat(1, ''), text='Text'))
    # State written by the callbacks below and inspected by the tests.
    received = None
    count = 0

    @pytest.fixture(autouse=True)
    def reset(self):
        """Clear the callback state before every test."""
        self.received = None
        self.count = 0

    def error_handler(self, bot, update, error):
        """Error handler: store the error's ``message`` in ``self.received``."""
        self.received = error.message

    def callback_increase_count(self, bot, update):
        """Handler callback: increment ``self.count``."""
        self.count += 1

    def callback_set_count(self, count):
        """Return a handler callback that sets ``self.count`` to ``count``."""
        def callback(bot, update):
            self.count = count

        return callback

    def callback_raise_error(self, bot, update):
        """Handler callback: raise a TelegramError carrying the message text."""
        raise TelegramError(update.message.text)

    def callback_if_not_update_queue(self, bot, update, update_queue=None):
        """Handler callback: store the message only if an ``update_queue`` was passed in."""
        if update_queue is not None:
            self.received = update.message

    def test_error_handler(self, dp):
        """An error put on the queue reaches the error handler until it is removed."""
        dp.add_error_handler(self.error_handler)
        error = TelegramError('Unauthorized.')
        dp.update_queue.put(error)
        sleep(.1)
        assert self.received == 'Unauthorized.'

        # Remove handler. Reset the state by hand: calling the ``reset`` fixture directly
        # is deprecated in pytest 3.x and a hard error from pytest 4 onwards.
        dp.remove_error_handler(self.error_handler)
        self.received = None

        dp.update_queue.put(error)
        sleep(.1)
        assert self.received is None

    def test_run_async_multiple(self, bot, dp, dp2):
        """Async jobs of two different dispatchers run on different threads."""
        def get_dispatcher_name(q):
            q.put(current_thread().name)

        q1 = Queue()
        q2 = Queue()

        dp.run_async(get_dispatcher_name, q1)
        dp2.run_async(get_dispatcher_name, q2)

        # Bounded wait for each job to report its thread name; an unbounded get() would
        # hang the whole suite if a job never ran.
        name1 = q1.get(timeout=5)
        name2 = q2.get(timeout=5)

        assert name1 != name2

    def test_multiple_run_async_decorator(self, dp, dp2):
        """With two dispatchers alive, calling a @run_async-decorated function raises."""
        # Make sure we got two dispatchers and that they are not the same
        assert isinstance(dp, Dispatcher)
        assert isinstance(dp2, Dispatcher)
        assert dp is not dp2

        @run_async
        def must_raise_runtime_error():
            pass

        with pytest.raises(RuntimeError):
            must_raise_runtime_error()

    def test_run_async_with_args(self, dp):
        """A run_async-wrapped callback still receives the pass_update_queue argument."""
        dp.add_handler(MessageHandler(Filters.all, run_async(self.callback_if_not_update_queue),
                                      pass_update_queue=True))

        dp.update_queue.put(self.message_update)
        sleep(.1)
        assert self.received == self.message_update.message

    def test_error_in_handler(self, dp):
        """A TelegramError raised inside a handler is passed to the error handler."""
        dp.add_handler(MessageHandler(Filters.all, self.callback_raise_error))
        dp.add_error_handler(self.error_handler)

        dp.update_queue.put(self.message_update)
        sleep(.1)
        assert self.received == self.message_update.message.text

    def test_add_remove_handler(self, dp):
        """A removed handler no longer receives updates."""
        handler = MessageHandler(Filters.all, self.callback_increase_count)
        dp.add_handler(handler)
        dp.update_queue.put(self.message_update)
        sleep(.1)
        assert self.count == 1

        dp.remove_handler(handler)
        dp.update_queue.put(self.message_update)
        # Give the dispatcher time to process the second update; without this wait the
        # assertion runs before the update is handled and cannot detect a failed removal.
        sleep(.1)
        assert self.count == 1

    def test_add_remove_handler_non_default_group(self, dp):
        """remove_handler must be given the group the handler was added to."""
        handler = MessageHandler(Filters.all, self.callback_increase_count)
        dp.add_handler(handler, group=2)
        with pytest.raises(KeyError):
            dp.remove_handler(handler)
        dp.remove_handler(handler, group=2)

    def test_error_start_twice(self, dp):
        """Calling start() on an already running dispatcher does not raise."""
        assert dp.running
        dp.start()

    def test_handler_order_in_group(self, dp):
        """Within one group, only the first matching handler (by insertion order) runs."""
        dp.add_handler(MessageHandler(Filters.photo, self.callback_set_count(1)))
        dp.add_handler(MessageHandler(Filters.all, self.callback_set_count(2)))
        dp.add_handler(MessageHandler(Filters.text, self.callback_set_count(3)))
        dp.update_queue.put(self.message_update)
        sleep(.1)
        assert self.count == 2

    def test_groups(self, dp):
        """Each handler group gets its own chance to handle the update."""
        dp.add_handler(MessageHandler(Filters.all, self.callback_increase_count))
        dp.add_handler(MessageHandler(Filters.all, self.callback_increase_count), group=2)
        dp.add_handler(MessageHandler(Filters.all, self.callback_increase_count), group=-1)

        dp.update_queue.put(self.message_update)
        sleep(.1)
        assert self.count == 3

    def test_add_handler_errors(self, dp):
        """add_handler rejects non-handlers and non-integer groups with TypeError."""
        handler = 'not a handler'
        with pytest.raises(TypeError, match='handler is not an instance of'):
            dp.add_handler(handler)

        handler = MessageHandler(Filters.photo, self.callback_set_count(1))
        with pytest.raises(TypeError, match='group is not int'):
            dp.add_handler(handler, 'one')

    def test_handler_flow_continue(self, bot, dp):
        """Raising DispatcherHandlerContinue lets the next handler in the same group run."""
        passed = []

        def start1(b, u):
            passed.append('start1')
            raise DispatcherHandlerContinue

        def start2(b, u):
            passed.append('start2')

        update = Update(1, message=Message(1, None, None, None, text='/start', bot=bot))

        # Both handlers are in the default group; start1's Continue lets start2 run too.
        dp.add_handler(CommandHandler('start', start1))
        dp.add_handler(CommandHandler('start', start2))
        dp.process_update(update)
        assert passed == ['start1', 'start2']

    def test_dispatcher_handler_flow_stop(self, dp, bot):
        """Raising DispatcherHandlerStop prevents handlers in later groups from running."""
        passed = []

        def start1(b, u):
            passed.append('start1')
            raise DispatcherHandlerStop

        def start2(b, u):
            passed.append('start2')

        def start3(b, u):
            passed.append('start3')

        update = Update(1, message=Message(1, None, None, None, text='/start', bot=bot))

        # If Stop is raised, handlers in other groups should not be called.
        dp.add_handler(CommandHandler('start', start1), 1)
        dp.add_handler(CommandHandler('start', start3), 1)
        dp.add_handler(CommandHandler('start', start2), 2)
        dp.process_update(update)
        assert passed == ['start1']