#!/usr/bin/env python # # A library that provides a Python interface to the Telegram Bot API # Copyright (C) 2015-2024 # Leandro Toledo de Souza # # This program is free software: you can redistribute it and/or modify # it under the terms of the GNU Lesser Public License as published by # the Free Software Foundation, either version 3 of the License, or # (at your option) any later version. # # This program is distributed in the hope that it will be useful, # but WITHOUT ANY WARRANTY; without even the implied warranty of # MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the # GNU Lesser Public License for more details. # # You should have received a copy of the GNU Lesser Public License # along with this program. If not, see [http://www.gnu.org/licenses/]. """The integration of persistence into the application is tested in test_basepersistence. """ import asyncio import functools import inspect import logging import os import platform import signal import sys import threading import time from collections import defaultdict from pathlib import Path from queue import Queue from random import randrange from threading import Thread from typing import Optional import pytest from telegram import Bot, Chat, Message, MessageEntity, User from telegram.error import TelegramError from telegram.ext import ( Application, ApplicationBuilder, ApplicationHandlerStop, BaseHandler, CallbackContext, CommandHandler, ContextTypes, Defaults, JobQueue, MessageHandler, PicklePersistence, SimpleUpdateProcessor, TypeHandler, Updater, filters, ) from telegram.warnings import PTBDeprecationWarning, PTBUserWarning from tests.auxil.asyncio_helpers import call_after from tests.auxil.build_messages import make_message_update from tests.auxil.files import PROJECT_ROOT_PATH from tests.auxil.monkeypatch import empty_get_updates, return_true from tests.auxil.networking import send_webhook_message from tests.auxil.pytest_classes import PytestApplication, PytestUpdater, make_bot from tests.auxil.slots import mro_slots class CustomContext(CallbackContext): pass class TestApplication: """The integration of persistence into the application is tested in test_basepersistence. """ message_update = make_message_update(message="Text") received = None count = 0 @pytest.fixture(autouse=True, name="reset") def _reset_fixture(self): self.reset() def reset(self): self.received = None self.count = 0 async def error_handler_context(self, update, context): self.received = context.error.message async def error_handler_raise_error(self, update, context): raise Exception("Failing bigly") async def callback_increase_count(self, update, context): self.count += 1 def callback_set_count(self, count, sleep: Optional[float] = None): async def callback(update, context): if sleep: await asyncio.sleep(sleep) self.count = count return callback def callback_raise_error(self, error_message: str): async def callback(update, context): raise TelegramError(error_message) return callback async def callback_received(self, update, context): self.received = update.message async def callback_context(self, update, context): if ( isinstance(context, CallbackContext) and isinstance(context.bot, Bot) and isinstance(context.update_queue, Queue) and isinstance(context.job_queue, JobQueue) and isinstance(context.error, TelegramError) ): self.received = context.error.message async def test_slot_behaviour(self, one_time_bot): async with ApplicationBuilder().bot(one_time_bot).build() as app: for at in app.__slots__: attr = f"_Application{at}" if at.startswith("__") and not at.endswith("__") else at assert getattr(app, attr, "err") != "err", f"got extra slot '{attr}'" assert len(mro_slots(app)) == len(set(mro_slots(app))), "duplicate slot" def test_manual_init_warning(self, recwarn, updater): Application( bot=None, update_queue=None, job_queue=None, persistence=None, context_types=ContextTypes(), updater=updater, update_processor=False, post_init=None, post_shutdown=None, post_stop=None, ) assert len(recwarn) == 1 assert ( str(recwarn[-1].message) == "`Application` instances should be built via the `ApplicationBuilder`." ) assert recwarn[0].category is PTBUserWarning assert recwarn[0].filename == __file__, "stacklevel is incorrect!" @pytest.mark.filterwarnings("ignore: `Application` instances should") def test_init(self, one_time_bot): update_queue = asyncio.Queue() job_queue = JobQueue() persistence = PicklePersistence("file_path") context_types = ContextTypes() update_processor = SimpleUpdateProcessor(1) updater = Updater(bot=one_time_bot, update_queue=update_queue) async def post_init(application: Application) -> None: pass async def post_shutdown(application: Application) -> None: pass async def post_stop(application: Application) -> None: pass app = Application( bot=one_time_bot, update_queue=update_queue, job_queue=job_queue, persistence=persistence, context_types=context_types, updater=updater, update_processor=update_processor, post_init=post_init, post_shutdown=post_shutdown, post_stop=post_stop, ) assert app.bot is one_time_bot assert app.update_queue is update_queue assert app.job_queue is job_queue assert app.persistence is persistence assert app.context_types is context_types assert app.updater is updater assert app.update_queue is updater.update_queue assert app.bot is updater.bot assert app.update_processor is update_processor assert app.post_init is post_init assert app.post_shutdown is post_shutdown assert app.post_stop is post_stop # These should be done by the builder assert app.persistence.bot is None with pytest.raises(RuntimeError, match="No application was set"): app.job_queue.application assert isinstance(app.bot_data, dict) assert isinstance(app.chat_data[1], dict) assert isinstance(app.user_data[1], dict) async def test_repr(self, app): assert repr(app) == f"PytestApplication[bot={app.bot!r}]" def test_job_queue(self, one_time_bot, app, recwarn): expected_warning = ( "No `JobQueue` set up. To use `JobQueue`, you must install PTB via " '`pip install "python-telegram-bot[job-queue]"`.' ) assert app.job_queue is app._job_queue application = ApplicationBuilder().bot(one_time_bot).job_queue(None).build() assert application.job_queue is None assert len(recwarn) == 1 assert str(recwarn[0].message) == expected_warning assert recwarn[0].category is PTBUserWarning assert recwarn[0].filename == __file__, "wrong stacklevel" def test_custom_context_init(self, one_time_bot): cc = ContextTypes( context=CustomContext, user_data=int, chat_data=float, bot_data=complex, ) application = ApplicationBuilder().bot(one_time_bot).context_types(cc).build() assert isinstance(application.user_data[1], int) assert isinstance(application.chat_data[1], float) assert isinstance(application.bot_data, complex) @pytest.mark.parametrize("updater", [True, False]) async def test_initialize(self, one_time_bot, monkeypatch, updater): """Initialization of persistence is tested test_basepersistence""" self.test_flag = set() async def after_initialize_bot(*args, **kwargs): self.test_flag.add("bot") async def after_initialize_update_processor(*args, **kwargs): self.test_flag.add("update_processor") async def after_initialize_updater(*args, **kwargs): self.test_flag.add("updater") update_processor = SimpleUpdateProcessor(1) monkeypatch.setattr(Bot, "initialize", call_after(Bot.initialize, after_initialize_bot)) monkeypatch.setattr( SimpleUpdateProcessor, "initialize", call_after(SimpleUpdateProcessor.initialize, after_initialize_update_processor), ) monkeypatch.setattr( Updater, "initialize", call_after(Updater.initialize, after_initialize_updater) ) if updater: app = ( ApplicationBuilder().bot(one_time_bot).concurrent_updates(update_processor).build() ) await app.initialize() assert self.test_flag == {"bot", "update_processor", "updater"} await app.shutdown() else: app = ( ApplicationBuilder() .bot(one_time_bot) .updater(None) .concurrent_updates(update_processor) .build() ) await app.initialize() assert self.test_flag == {"bot", "update_processor"} await app.shutdown() @pytest.mark.parametrize("updater", [True, False]) async def test_shutdown(self, one_time_bot, monkeypatch, updater): """Shutdown of persistence is tested in test_basepersistence""" self.test_flag = set() def after_bot_shutdown(*args, **kwargs): self.test_flag.add("bot") def after_shutdown_update_processor(*args, **kwargs): self.test_flag.add("update_processor") def after_updater_shutdown(*args, **kwargs): self.test_flag.add("updater") update_processor = SimpleUpdateProcessor(1) monkeypatch.setattr(Bot, "shutdown", call_after(Bot.shutdown, after_bot_shutdown)) monkeypatch.setattr( SimpleUpdateProcessor, "shutdown", call_after(SimpleUpdateProcessor.shutdown, after_shutdown_update_processor), ) monkeypatch.setattr( Updater, "shutdown", call_after(Updater.shutdown, after_updater_shutdown) ) if updater: async with ( ApplicationBuilder().bot(one_time_bot).concurrent_updates(update_processor).build() ): pass assert self.test_flag == {"bot", "update_processor", "updater"} else: async with ( ApplicationBuilder() .bot(one_time_bot) .updater(None) .concurrent_updates(update_processor) .build() ): pass assert self.test_flag == {"bot", "update_processor"} async def test_multiple_inits_and_shutdowns(self, app, monkeypatch): self.received = defaultdict(int) async def after_initialize(*args, **kargs): self.received["init"] += 1 async def after_shutdown(*args, **kwargs): self.received["shutdown"] += 1 monkeypatch.setattr( app.bot, "initialize", call_after(app.bot.initialize, after_initialize) ) monkeypatch.setattr(app.bot, "shutdown", call_after(app.bot.shutdown, after_shutdown)) await app.initialize() await app.initialize() await app.initialize() await app.shutdown() await app.shutdown() await app.shutdown() # 2 instead of 1 since `Updater.initialize` also calls bot.init/shutdown assert self.received["init"] == 2 assert self.received["shutdown"] == 2 async def test_multiple_init_cycles(self, app): # nothing really to assert - this should just not fail async with app: await app.bot.get_me() async with app: await app.bot.get_me() async def test_start_without_initialize(self, app): with pytest.raises(RuntimeError, match="not initialized"): await app.start() async def test_shutdown_while_running(self, app): async with app: await app.start() with pytest.raises(RuntimeError, match="still running"): await app.shutdown() await app.stop() async def test_start_not_running_after_failure(self, one_time_bot, monkeypatch): def start(_): raise Exception("Test Exception") monkeypatch.setattr(JobQueue, "start", start) app = ApplicationBuilder().bot(one_time_bot).job_queue(JobQueue()).build() async with app: with pytest.raises(Exception, match="Test Exception"): await app.start() assert app.running is False async def test_context_manager(self, monkeypatch, app): self.test_flag = set() async def after_initialize(*args, **kwargs): self.test_flag.add("initialize") async def after_shutdown(*args, **kwargs): self.test_flag.add("stop") monkeypatch.setattr( Application, "initialize", call_after(Application.initialize, after_initialize) ) monkeypatch.setattr( Application, "shutdown", call_after(Application.shutdown, after_shutdown) ) async with app: pass assert self.test_flag == {"initialize", "stop"} async def test_context_manager_exception_on_init(self, monkeypatch, app): async def after_initialize(*args, **kwargs): raise RuntimeError("initialize") async def after_shutdown(*args): self.test_flag = "stop" monkeypatch.setattr( Application, "initialize", call_after(Application.initialize, after_initialize) ) monkeypatch.setattr( Application, "shutdown", call_after(Application.shutdown, after_shutdown) ) with pytest.raises(RuntimeError, match="initialize"): async with app: pass assert self.test_flag == "stop" @pytest.mark.parametrize("data", ["chat_data", "user_data"]) def test_chat_user_data_read_only(self, app, data): read_only_data = getattr(app, data) writable_data = getattr(app, f"_{data}") writable_data[123] = 321 assert read_only_data == writable_data with pytest.raises(TypeError): read_only_data[111] = 123 def test_builder(self, app): builder_1 = app.builder() builder_2 = app.builder() assert isinstance(builder_1, ApplicationBuilder) assert isinstance(builder_2, ApplicationBuilder) assert builder_1 is not builder_2 # Make sure that setting a token doesn't raise an exception # i.e. check that the builders are "empty"/new builder_1.token(app.bot.token) builder_2.token(app.bot.token) @pytest.mark.parametrize("job_queue", [True, False]) @pytest.mark.filterwarnings("ignore::telegram.warnings.PTBUserWarning") async def test_start_stop_processing_updates(self, one_time_bot, job_queue, monkeypatch): # TODO: repeat a similar test for create_task, persistence processing and job queue if job_queue: app = ApplicationBuilder().bot(one_time_bot).build() else: app = ApplicationBuilder().bot(one_time_bot).job_queue(None).build() async def callback(u, c): self.received = u monkeypatch.setattr(app.bot, "get_updates", empty_get_updates) monkeypatch.setattr(app.bot, "delete_webhook", return_true) assert not app.running assert not app.updater.running if job_queue: assert not app.job_queue.scheduler.running else: assert app.job_queue is None app.add_handler(TypeHandler(object, callback)) await app.update_queue.put(1) await asyncio.sleep(0.05) assert not app.update_queue.empty() assert self.received is None async with app: await app.start() assert app.running tasks = asyncio.all_tasks() assert any(":update_fetcher" in task.get_name() for task in tasks) if job_queue: assert app.job_queue.scheduler.running else: assert app.job_queue is None # app.start() should not start the updater! assert not app.updater.running await asyncio.sleep(0.05) assert app.update_queue.empty() assert self.received == 1 try: # just in case start_polling times out await app.updater.start_polling() except TelegramError: pytest.xfail("start_polling timed out") else: await app.stop() assert not app.running # app.stop() should not stop the updater! assert app.updater.running if job_queue: assert not app.job_queue.scheduler.running else: assert app.job_queue is None await app.update_queue.put(2) await asyncio.sleep(0.05) assert not app.update_queue.empty() assert self.received != 2 assert self.received == 1 await app.updater.stop() async def test_error_start_stop_twice(self, app): async with app: await app.start() assert app.running with pytest.raises(RuntimeError, match="already running"): await app.start() await app.stop() assert not app.running with pytest.raises(RuntimeError, match="not running"): await app.stop() async def test_one_context_per_update(self, app): self.received = None async def one(update, context): self.received = context async def two(update, context): if update.message.text == "test": if context is not self.received: pytest.fail("Expected same context object, got different") elif context is self.received: pytest.fail("First handler was wrongly called") async with app: app.add_handler(MessageHandler(filters.Regex("test"), one), group=1) app.add_handler(MessageHandler(filters.ALL, two), group=2) u = make_message_update(message="test") await app.process_update(u) self.received = None u = make_message_update(message="something") await app.process_update(u) def test_add_handler_errors(self, app): handler = "not a handler" with pytest.raises(TypeError, match="handler is not an instance of"): app.add_handler(handler) handler = MessageHandler(filters.PHOTO, self.callback_set_count(1)) with pytest.raises(TypeError, match="group is not int"): app.add_handler(handler, "one") @pytest.mark.parametrize("group_empty", [True, False]) async def test_add_remove_handler(self, app, group_empty): handler = MessageHandler(filters.ALL, self.callback_increase_count) app.add_handler(handler) if not group_empty: app.add_handler(handler) async with app: await app.start() await app.update_queue.put(self.message_update) await asyncio.sleep(0.05) assert self.count == 1 app.remove_handler(handler) assert (0 in app.handlers) == (not group_empty) await app.update_queue.put(self.message_update) assert self.count == 1 await app.stop() async def test_add_remove_handler_non_default_group(self, app): handler = MessageHandler(filters.ALL, self.callback_increase_count) app.add_handler(handler, group=2) with pytest.raises(KeyError): app.remove_handler(handler) app.remove_handler(handler, group=2) async def test_handler_order_in_group(self, app): app.add_handler(MessageHandler(filters.PHOTO, self.callback_set_count(1))) app.add_handler(MessageHandler(filters.ALL, self.callback_set_count(2))) app.add_handler(MessageHandler(filters.TEXT, self.callback_set_count(3))) async with app: await app.start() await app.update_queue.put(self.message_update) await asyncio.sleep(0.05) assert self.count == 2 await app.stop() async def test_groups(self, app): app.add_handler(MessageHandler(filters.ALL, self.callback_increase_count)) app.add_handler(MessageHandler(filters.ALL, self.callback_increase_count), group=2) app.add_handler(MessageHandler(filters.ALL, self.callback_increase_count), group=-1) async with app: await app.start() await app.update_queue.put(self.message_update) await asyncio.sleep(0.05) assert self.count == 3 await app.stop() async def test_add_handlers(self, app): """Tests both add_handler & add_handlers together & confirms the correct insertion order""" msg_handler_set_count = MessageHandler(filters.TEXT, self.callback_set_count(1)) msg_handler_inc_count = MessageHandler(filters.PHOTO, self.callback_increase_count) app.add_handler(msg_handler_set_count, 1) app.add_handlers((msg_handler_inc_count, msg_handler_inc_count), 1) photo_update = make_message_update(message=Message(2, None, None, photo=(True,))) async with app: await app.start() # Putting updates in the queue calls the callback await app.update_queue.put(self.message_update) await app.update_queue.put(photo_update) await asyncio.sleep(0.05) # sleep is required otherwise there is random behaviour # Test if handler was added to correct group with correct order- assert self.count == 2 assert len(app.handlers[1]) == 3 assert app.handlers[1][0] is msg_handler_set_count # Now lets test add_handlers when `handlers` is a dict- voice_filter_handler_to_check = MessageHandler( filters.VOICE, self.callback_increase_count ) app.add_handlers( handlers={ 1: [ MessageHandler(filters.USER, self.callback_increase_count), voice_filter_handler_to_check, ], -1: [MessageHandler(filters.CAPTION, self.callback_set_count(2))], } ) user_update = make_message_update( message=Message(3, None, None, from_user=User(1, "s", True)) ) voice_update = make_message_update(message=Message(4, None, None, voice=True)) await app.update_queue.put(user_update) await app.update_queue.put(voice_update) await asyncio.sleep(0.05) assert self.count == 4 assert len(app.handlers[1]) == 5 assert app.handlers[1][-1] is voice_filter_handler_to_check await app.update_queue.put( make_message_update(message=Message(5, None, None, caption="cap")) ) await asyncio.sleep(0.05) assert self.count == 2 assert len(app.handlers[-1]) == 1 # Now lets test the errors which can be produced- with pytest.raises(TypeError, match="The `group` argument"): app.add_handlers({2: [msg_handler_set_count]}, group=0) with pytest.raises(TypeError, match="Handlers for group 3"): app.add_handlers({3: msg_handler_set_count}) with pytest.raises(TypeError, match="The `handlers` argument must be a sequence"): app.add_handlers({msg_handler_set_count}) await app.stop() async def test_check_update(self, app): class TestHandler(BaseHandler): def check_update(_, update: object): self.received = object() def handle_update( _, update, application, check_result, context, ): assert application is app assert check_result is not self.received async with app: app.add_handler(TestHandler("callback")) await app.start() await app.update_queue.put(object()) await asyncio.sleep(0.05) await app.stop() async def test_flow_stop(self, app, one_time_bot): passed = [] async def start1(b, u): passed.append("start1") raise ApplicationHandlerStop async def start2(b, u): passed.append("start2") async def start3(b, u): passed.append("start3") update = make_message_update( message=Message( 1, None, None, None, text="/start", entities=[ MessageEntity(type=MessageEntity.BOT_COMMAND, offset=0, length=len("/start")) ], ), ) await one_time_bot.initialize() update.message.set_bot(one_time_bot) async with app: # If ApplicationHandlerStop raised handlers in other groups should not be called. passed = [] app.add_handler(CommandHandler("start", start1), 1) app.add_handler(CommandHandler("start", start3), 1) app.add_handler(CommandHandler("start", start2), 2) await app.process_update(update) assert passed == ["start1"] async def test_flow_stop_by_error_handler(self, app): passed = [] exception = Exception("General exception") async def start1(u, c): passed.append("start1") raise exception async def start2(u, c): passed.append("start2") async def start3(u, c): passed.append("start3") async def error(u, c): passed.append("error") passed.append(c.error) raise ApplicationHandlerStop async with app: # If ApplicationHandlerStop raised handlers in other groups should not be called. passed = [] app.add_error_handler(error) app.add_handler(TypeHandler(object, start1), 1) app.add_handler(TypeHandler(object, start2), 1) app.add_handler(TypeHandler(object, start3), 2) await app.process_update(1) assert passed == ["start1", "error", exception] async def test_error_in_handler_part_1(self, app): app.add_handler( MessageHandler( filters.ALL, self.callback_raise_error(error_message=self.message_update.message.text), ) ) app.add_handler(MessageHandler(filters.ALL, self.callback_set_count(42)), group=1) app.add_error_handler(self.error_handler_context) async with app: await app.start() await app.update_queue.put(self.message_update) await asyncio.sleep(0.05) await app.stop() assert self.received == self.message_update.message.text # Higher groups should still be called assert self.count == 42 async def test_error_in_handler_part_2(self, app, one_time_bot): passed = [] err = Exception("General exception") async def start1(u, c): passed.append("start1") raise err async def start2(u, c): passed.append("start2") async def start3(u, c): passed.append("start3") async def error(u, c): passed.append("error") passed.append(c.error) update = make_message_update( message=Message( 1, None, None, None, text="/start", entities=[ MessageEntity(type=MessageEntity.BOT_COMMAND, offset=0, length=len("/start")) ], ), ) await one_time_bot.initialize() update.message.set_bot(one_time_bot) async with app: # If an unhandled exception was caught, no further handlers from the same group should # be called. Also, the error handler should be called and receive the exception passed = [] app.add_handler(CommandHandler("start", start1), 1) app.add_handler(CommandHandler("start", start2), 1) app.add_handler(CommandHandler("start", start3), 2) app.add_error_handler(error) await app.process_update(update) assert passed == ["start1", "error", err, "start3"] @pytest.mark.parametrize("block", [True, False]) async def test_error_handler(self, app, block): app.add_error_handler(self.error_handler_context) app.add_handler(TypeHandler(object, self.callback_raise_error("TestError"), block=block)) async with app: await app.start() await app.update_queue.put(1) await asyncio.sleep(0.05) assert self.received == "TestError" # Remove handler app.remove_error_handler(self.error_handler_context) self.reset() await app.update_queue.put(1) await asyncio.sleep(0.05) assert self.received is None await app.stop() def test_double_add_error_handler(self, app, caplog): app.add_error_handler(self.error_handler_context) with caplog.at_level(logging.DEBUG): app.add_error_handler(self.error_handler_context) assert len(caplog.records) == 1 assert caplog.records[-1].name == "telegram.ext.Application" assert caplog.records[-1].getMessage().startswith("The callback is already registered") async def test_error_handler_that_raises_errors(self, app, caplog): """Make sure that errors raised in error handlers don't break the main loop of the application """ handler_raise_error = TypeHandler( int, self.callback_raise_error(error_message="TestError") ) handler_increase_count = TypeHandler(str, self.callback_increase_count) app.add_error_handler(self.error_handler_raise_error) app.add_handler(handler_raise_error) app.add_handler(handler_increase_count) with caplog.at_level(logging.ERROR): async with app: await app.start() await app.update_queue.put(1) await asyncio.sleep(0.05) assert self.count == 0 assert self.received is None assert len(caplog.records) > 0 assert any( "uncaught error was raised while handling the error with an error_handler" in record.getMessage() and record.name == "telegram.ext.Application" for record in caplog.records ) await app.update_queue.put("1") self.received = None caplog.clear() await asyncio.sleep(0.05) assert self.count == 1 assert self.received is None assert not caplog.records await app.stop() async def test_custom_context_error_handler(self, one_time_bot): async def error_handler(_, context): self.received = ( type(context), type(context.user_data), type(context.chat_data), type(context.bot_data), ) application = ( ApplicationBuilder() .bot(one_time_bot) .context_types( ContextTypes( context=CustomContext, bot_data=int, user_data=float, chat_data=complex ) ) .build() ) application.add_error_handler(error_handler) application.add_handler( MessageHandler(filters.ALL, self.callback_raise_error("TestError")) ) async with application: await application.process_update(self.message_update) await asyncio.sleep(0.05) assert self.received == (CustomContext, float, complex, int) async def test_custom_context_handler_callback(self, one_time_bot): async def callback(_, context): self.received = ( type(context), type(context.user_data), type(context.chat_data), type(context.bot_data), ) application = ( ApplicationBuilder() .bot(one_time_bot) .context_types( ContextTypes( context=CustomContext, bot_data=int, user_data=float, chat_data=complex ) ) .build() ) application.add_handler(MessageHandler(filters.ALL, callback)) async with application: await application.process_update(self.message_update) await asyncio.sleep(0.05) assert self.received == (CustomContext, float, complex, int) @pytest.mark.parametrize( ("check", "expected"), [(True, True), (None, False), (False, False), ({}, True), ("", True), ("check", True)], ) async def test_check_update_handling(self, app, check, expected): class MyHandler(BaseHandler): def check_update(self, update: object): return check async def handle_update( _, update, application, check_result, context, ): await super().handle_update( update=update, application=application, check_result=check_result, context=context, ) self.received = check_result async with app: app.add_handler(MyHandler(self.callback_increase_count)) await app.process_update(1) assert self.count == (1 if expected else 0) if expected: assert self.received == check else: assert self.received is None async def test_non_blocking_handler(self, app): event = asyncio.Event() async def callback(update, context): await event.wait() self.count = 42 app.add_handler(TypeHandler(object, callback, block=False)) app.add_handler(TypeHandler(object, self.callback_increase_count), group=1) async with app: await app.start() await app.update_queue.put(1) task = asyncio.create_task(app.stop()) await asyncio.sleep(0.05) tasks = asyncio.all_tasks() assert any(":process_update_non_blocking" in t.get_name() for t in tasks) assert self.count == 1 # Make sure that app stops only once all non blocking callbacks are done assert not task.done() event.set() await asyncio.sleep(0.05) assert self.count == 42 assert task.done() async def test_non_blocking_handler_applicationhandlerstop(self, app, recwarn): async def callback(update, context): raise ApplicationHandlerStop app.add_handler(TypeHandler(object, callback, block=False)) async with app: await app.start() await app.update_queue.put(1) await asyncio.sleep(0.05) await app.stop() assert len(recwarn) == 1 assert recwarn[0].category is PTBUserWarning assert ( str(recwarn[0].message) == "ApplicationHandlerStop is not supported with handlers running non-blocking." ) assert ( Path(recwarn[0].filename) == PROJECT_ROOT_PATH / "telegram" / "ext" / "_application.py" ), "incorrect stacklevel!" async def test_non_blocking_no_error_handler(self, app, caplog): app.add_handler(TypeHandler(object, self.callback_raise_error("Test error"), block=False)) with caplog.at_level(logging.ERROR): async with app: await app.start() await app.update_queue.put(1) await asyncio.sleep(0.05) assert len(caplog.records) == 1 assert ( caplog.records[-1].getMessage().startswith("No error handlers are registered") ) assert caplog.records[-1].name == "telegram.ext.Application" await app.stop() @pytest.mark.parametrize("handler_block", [True, False]) async def test_non_blocking_error_handler(self, app, handler_block): event = asyncio.Event() async def async_error_handler(update, context): await event.wait() self.received = "done" async def normal_error_handler(update, context): self.count = 42 app.add_error_handler(async_error_handler, block=False) app.add_error_handler(normal_error_handler) app.add_handler(TypeHandler(object, self.callback_raise_error("err"), block=handler_block)) async with app: await app.start() await app.update_queue.put(self.message_update) task = asyncio.create_task(app.stop()) await asyncio.sleep(0.05) tasks = asyncio.all_tasks() assert any(":process_error:non_blocking" in t.get_name() for t in tasks) assert self.count == 42 assert self.received is None event.set() await asyncio.sleep(0.05) assert self.received == "done" assert task.done() @pytest.mark.parametrize("handler_block", [True, False]) async def test_non_blocking_error_handler_applicationhandlerstop( self, app, recwarn, handler_block ): async def callback(update, context): raise RuntimeError async def error_handler(update, context): raise ApplicationHandlerStop app.add_handler(TypeHandler(object, callback, block=handler_block)) app.add_error_handler(error_handler, block=False) async with app: await app.start() await app.update_queue.put(1) await asyncio.sleep(0.05) await app.stop() assert len(recwarn) == 1 assert recwarn[0].category is PTBUserWarning assert ( str(recwarn[0].message) == "ApplicationHandlerStop is not supported with handlers running non-blocking." ) assert ( Path(recwarn[0].filename) == PROJECT_ROOT_PATH / "telegram" / "ext" / "_application.py" ), "incorrect stacklevel!" @pytest.mark.parametrize(("block", "expected_output"), [(False, 0), (True, 5)]) async def test_default_block_error_handler(self, bot_info, block, expected_output): async def error_handler(*args, **kwargs): await asyncio.sleep(0.1) self.count = 5 bot = make_bot(bot_info, defaults=Defaults(block=block)) app = Application.builder().bot(bot).build() async with app: app.add_handler(TypeHandler(object, self.callback_raise_error("error"))) app.add_error_handler(error_handler) await app.process_update(1) await asyncio.sleep(0.05) assert self.count == expected_output await asyncio.sleep(0.1) assert self.count == 5 @pytest.mark.parametrize(("block", "expected_output"), [(False, 0), (True, 5)]) async def test_default_block_handler(self, bot_info, block, expected_output): bot = make_bot(bot_info, defaults=Defaults(block=block)) app = Application.builder().bot(bot).build() async with app: app.add_handler(TypeHandler(object, self.callback_set_count(5, sleep=0.1))) await app.process_update(1) await asyncio.sleep(0.05) assert self.count == expected_output await asyncio.sleep(0.15) assert self.count == 5 @pytest.mark.parametrize("handler_block", [True, False]) @pytest.mark.parametrize("error_handler_block", [True, False]) async def test_nonblocking_handler_raises_and_non_blocking_error_handler_raises( self, app, caplog, handler_block, error_handler_block ): handler = TypeHandler(object, self.callback_raise_error("error"), block=handler_block) app.add_handler(handler) app.add_error_handler(self.error_handler_raise_error, block=error_handler_block) async with app: await app.start() with caplog.at_level(logging.ERROR): await app.update_queue.put(1) await asyncio.sleep(0.05) assert len(caplog.records) == 1 assert caplog.records[-1].name == "telegram.ext.Application" assert ( caplog.records[-1] .getMessage() .startswith("An error was raised and an uncaught") ) # Make sure that the main loop still runs app.remove_handler(handler) app.add_handler(MessageHandler(filters.ALL, self.callback_increase_count, block=True)) await app.update_queue.put(self.message_update) await asyncio.sleep(0.05) assert self.count == 1 await app.stop() @pytest.mark.parametrize( "message", [ Message(message_id=1, chat=Chat(id=2, type=None), migrate_from_chat_id=1, date=None), Message(message_id=1, chat=Chat(id=1, type=None), migrate_to_chat_id=2, date=None), Message(message_id=1, chat=Chat(id=1, type=None), date=None), None, ], ) @pytest.mark.parametrize("old_chat_id", [None, 1, "1"]) @pytest.mark.parametrize("new_chat_id", [None, 2, "1"]) def test_migrate_chat_data(self, app, message: "Message", old_chat_id: int, new_chat_id: int): def call(match: str): with pytest.raises(ValueError, match=match): app.migrate_chat_data( message=message, old_chat_id=old_chat_id, new_chat_id=new_chat_id ) if message and (old_chat_id or new_chat_id): call(r"^Message and chat_id pair are mutually exclusive$") return if not any((message, old_chat_id, new_chat_id)): call(r"^chat_id pair or message must be passed$") return if message: if message.migrate_from_chat_id is None and message.migrate_to_chat_id is None: call(r"^Invalid message instance") return effective_old_chat_id = message.migrate_from_chat_id or message.chat.id effective_new_chat_id = message.migrate_to_chat_id or message.chat.id elif not (isinstance(old_chat_id, int) and isinstance(new_chat_id, int)): call(r"^old_chat_id and new_chat_id must be integers$") return else: effective_old_chat_id = old_chat_id effective_new_chat_id = new_chat_id app.chat_data[effective_old_chat_id]["key"] = "test" app.migrate_chat_data(message=message, old_chat_id=old_chat_id, new_chat_id=new_chat_id) assert effective_old_chat_id not in app.chat_data assert app.chat_data[effective_new_chat_id]["key"] == "test" @pytest.mark.parametrize( ("c_id", "expected"), [(321, {222: "remove_me"}), (111, {321: {"not_empty": "no"}, 222: "remove_me"})], ids=["test chat_id removal", "test no key in data (no error)"], ) def test_drop_chat_data(self, app, c_id, expected): app._chat_data.update({321: {"not_empty": "no"}, 222: "remove_me"}) app.drop_chat_data(c_id) assert app.chat_data == expected @pytest.mark.parametrize( ("u_id", "expected"), [(321, {222: "remove_me"}), (111, {321: {"not_empty": "no"}, 222: "remove_me"})], ids=["test user_id removal", "test no key in data (no error)"], ) def test_drop_user_data(self, app, u_id, expected): app._user_data.update({321: {"not_empty": "no"}, 222: "remove_me"}) app.drop_user_data(u_id) assert app.user_data == expected async def test_create_task_basic(self, app): async def callback(): await asyncio.sleep(0.05) self.count = 42 return 43 task = app.create_task(callback(), name="test_task") assert task.get_name() == "test_task" await asyncio.sleep(0.01) assert not task.done() out = await task assert task.done() assert self.count == 42 assert out == 43 @pytest.mark.parametrize("running", [True, False]) async def test_create_task_awaiting_warning(self, app, running, recwarn): async def callback(): await asyncio.sleep(0.1) return 43 async with app: if running: await app.start() task = app.create_task(callback()) if running: assert len(recwarn) == 0 assert not task.done() await app.stop() assert task.done() assert task.result() == 43 else: assert len(recwarn) == 1 assert recwarn[0].category is PTBUserWarning assert "won't be automatically awaited" in str(recwarn[0].message) assert recwarn[0].filename == __file__, "wrong stacklevel!" assert not task.done() await task @pytest.mark.parametrize("update", [None, object()]) async def test_create_task_error_handling(self, app, update): exception = RuntimeError("TestError") async def callback(): raise exception async def error(update_arg, context): self.received = update_arg, context.error app.add_error_handler(error) if update: task = app.create_task(callback(), update=update) else: task = app.create_task(callback()) with pytest.raises(RuntimeError, match="TestError"): await task assert task.exception() is exception assert isinstance(self.received, tuple) assert self.received[0] is update assert self.received[1] is exception async def test_create_task_cancel_task(self, app): async def callback(): await asyncio.sleep(5) async def error(update_arg, context): self.received = update_arg, context.error app.add_error_handler(error) async with app: await app.start() task = app.create_task(callback()) await asyncio.sleep(0.05) task.cancel() with pytest.raises(asyncio.CancelledError): await task with pytest.raises(asyncio.CancelledError): assert task.exception() # Error handlers should not be called if task was cancelled assert self.received is None # make sure that the cancelled task doesn't block the stopping of the app await app.stop() async def test_await_create_task_tasks_on_stop(self, app): event_1 = asyncio.Event() event_2 = asyncio.Event() async def callback_1(): await event_1.wait() async def callback_2(): await event_2.wait() async with app: await app.start() task_1 = app.create_task(callback_1()) task_2 = app.create_task(callback_2()) event_2.set() await task_2 assert not task_1.done() stop_task = asyncio.create_task(app.stop()) assert not stop_task.done() await asyncio.sleep(0.1) assert not stop_task.done() event_1.set() await asyncio.sleep(0.05) assert stop_task.done() async def test_create_task_awaiting_future(self, app): async def callback(): await asyncio.sleep(0.01) return 42 # `asyncio.gather` returns an `asyncio.Future` and not an # `asyncio.Task` out = await app.create_task(asyncio.gather(callback())) assert out == [42] @pytest.mark.skipif(sys.version_info >= (3, 12), reason="generator coroutines are deprecated") async def test_create_task_awaiting_generator(self, app, recwarn): event = asyncio.Event() def gen(): yield event.set() await app.create_task(gen()) assert event.is_set() assert len(recwarn) == 2 # 1st warning is: tasks not being awaited when app isn't running assert recwarn[1].category is PTBDeprecationWarning assert "Generator-based coroutines are deprecated" in str(recwarn[1].message) async def test_no_update_processor(self, app): queue = asyncio.Queue() event_1 = asyncio.Event() event_2 = asyncio.Event() await queue.put(event_1) await queue.put(event_2) async def callback(u, c): await asyncio.sleep(0.1) event = await queue.get() event.set() app.add_handler(TypeHandler(object, callback)) async with app: await app.start() await app.update_queue.put(1) await app.update_queue.put(2) assert not event_1.is_set() assert not event_2.is_set() await asyncio.sleep(0.15) assert event_1.is_set() assert not event_2.is_set() await asyncio.sleep(0.1) assert event_1.is_set() assert event_2.is_set() await app.stop() @pytest.mark.parametrize("update_processor", [15, 50, 100]) async def test_update_processor(self, one_time_bot, update_processor): # We don't test with `True` since the large number of parallel coroutines quickly leads # to test instabilities app = Application.builder().bot(one_time_bot).concurrent_updates(update_processor).build() events = { i: asyncio.Event() for i in range(app.update_processor.max_concurrent_updates + 10) } queue = asyncio.Queue() for event in events.values(): await queue.put(event) async def callback(u, c): await asyncio.sleep(0.5) (await queue.get()).set() app.add_handler(TypeHandler(object, callback)) async with app: await app.start() for i in range(app.update_processor.max_concurrent_updates + 10): await app.update_queue.put(i) for i in range(app.update_processor.max_concurrent_updates + 10): assert not events[i].is_set() await asyncio.sleep(0.9) tasks = asyncio.all_tasks() assert any(":process_concurrent_update" in task.get_name() for task in tasks) for i in range(app.update_processor.max_concurrent_updates): assert events[i].is_set() for i in range( app.update_processor.max_concurrent_updates, app.update_processor.max_concurrent_updates + 10, ): assert not events[i].is_set() await asyncio.sleep(0.5) for i in range(app.update_processor.max_concurrent_updates + 10): assert events[i].is_set() await app.stop() async def test_update_processor_done_on_shutdown(self, one_time_bot): app = Application.builder().bot(one_time_bot).concurrent_updates(True).build() event = asyncio.Event() async def callback(update, context): await event.wait() app.add_handler(TypeHandler(object, callback)) async with app: await app.start() await app.update_queue.put(1) stop_task = asyncio.create_task(app.stop()) await asyncio.sleep(0.1) assert not stop_task.done() event.set() await asyncio.sleep(0.05) assert stop_task.done() @pytest.mark.skipif( platform.system() == "Windows", reason="Can't send signals without stopping whole process on windows", ) def test_run_polling_basic(self, app, monkeypatch, caplog): exception_event = threading.Event() exception_testing_done = threading.Event() update_event = threading.Event() exception = TelegramError("This is a test error") assertions = {} async def get_updates(*args, **kwargs): if exception_event.is_set(): raise exception # This makes sure that other coroutines have a chance of running as well if exception_testing_done.is_set() and app.updater.running: # the longer sleep makes sure that we can exit also while get_updates is running await asyncio.sleep(20) else: await asyncio.sleep(0.01) update_event.set() return [self.message_update] def thread_target(): waited = 0 while not app.running: time.sleep(0.05) waited += 0.05 if waited > 5: pytest.fail("App apparently won't start") # Check that everything's running assertions["app_running"] = app.running assertions["updater_running"] = app.updater.running assertions["job_queue_running"] = app.job_queue.scheduler.running # Check that we're getting updates update_event.wait() time.sleep(0.05) assertions["getting_updates"] = self.count == 42 # Check that errors are properly handled during polling exception_event.set() time.sleep(0.05) assertions["exception_handling"] = self.received == exception.message exception_testing_done.set() # So that the get_updates call on shutdown doesn't fail exception_event.clear() time.sleep(1) os.kill(os.getpid(), signal.SIGINT) time.sleep(0.1) # # Assert that everything has stopped running assertions["app_not_running"] = not app.running assertions["updater_not_running"] = not app.updater.running assertions["job_queue_not_running"] = not app.job_queue.scheduler.running monkeypatch.setattr(app.bot, "get_updates", get_updates) app.add_error_handler(self.error_handler_context) app.add_handler(TypeHandler(object, self.callback_set_count(42))) thread = Thread(target=thread_target) thread.start() with caplog.at_level(logging.DEBUG): app.run_polling(drop_pending_updates=True, close_loop=False) thread.join() assert len(assertions) == 8 for key, value in assertions.items(): assert value, f"assertion '{key}' failed!" found_log = False for record in caplog.records: if "received stop signal" in record.getMessage() and record.levelno == logging.DEBUG: found_log = True assert found_log @pytest.mark.parametrize( "timeout_name", ["read_timeout", "connect_timeout", "write_timeout", "pool_timeout", "poll_interval"], ) @pytest.mark.skipif( platform.system() == "Windows", reason="Can't send signals without stopping whole process on windows", ) def test_run_polling_timeout_deprecation_warnings( self, timeout_name, monkeypatch, recwarn, app ): def thread_target(): waited = 0 while not app.running: time.sleep(0.05) waited += 0.05 if waited > 5: pytest.fail("App apparently won't start") time.sleep(0.05) os.kill(os.getpid(), signal.SIGINT) monkeypatch.setattr(app.bot, "get_updates", empty_get_updates) thread = Thread(target=thread_target) thread.start() kwargs = {timeout_name: 42} app.run_polling(drop_pending_updates=True, close_loop=False, **kwargs) thread.join() if timeout_name == "poll_interval": assert len(recwarn) == 0 return assert len(recwarn) == 1 assert "Setting timeouts via `Application.run_polling` is deprecated." in str( recwarn[0].message ) assert recwarn[0].category is PTBDeprecationWarning assert recwarn[0].filename == __file__, "wrong stacklevel" @pytest.mark.skipif( platform.system() == "Windows", reason="Can't send signals without stopping whole process on windows", ) def test_run_polling_post_init(self, one_time_bot, monkeypatch): events = [] def thread_target(): waited = 0 while not app.running: time.sleep(0.05) waited += 0.05 if waited > 5: pytest.fail("App apparently won't start") os.kill(os.getpid(), signal.SIGINT) async def post_init(app: Application) -> None: events.append("post_init") app = ( Application.builder() .application_class(PytestApplication) .updater(PytestUpdater(one_time_bot, asyncio.Queue())) .post_init(post_init) .build() ) app.bot._unfreeze() monkeypatch.setattr(app.bot, "get_updates", empty_get_updates) monkeypatch.setattr( app, "initialize", call_after(app.initialize, lambda _: events.append("init")) ) monkeypatch.setattr( app.updater, "start_polling", call_after(app.updater.start_polling, lambda _: events.append("start_polling")), ) monkeypatch.setattr(app.bot, "delete_webhook", return_true) thread = Thread(target=thread_target) thread.start() app.run_polling(drop_pending_updates=True, close_loop=False) thread.join() assert events == ["init", "post_init", "start_polling"], "Wrong order of events detected!" @pytest.mark.skipif( platform.system() == "Windows", reason="Can't send signals without stopping whole process on windows", ) def test_run_polling_post_shutdown(self, one_time_bot, monkeypatch): events = [] def thread_target(): waited = 0 while not app.running: time.sleep(0.05) waited += 0.05 if waited > 5: pytest.fail("App apparently won't start") os.kill(os.getpid(), signal.SIGINT) async def post_shutdown(app: Application) -> None: events.append("post_shutdown") app = ( Application.builder() .application_class(PytestApplication) .updater(PytestUpdater(one_time_bot, asyncio.Queue())) .post_shutdown(post_shutdown) .build() ) app.bot._unfreeze() monkeypatch.setattr(app.bot, "get_updates", empty_get_updates) monkeypatch.setattr( app, "shutdown", call_after(app.shutdown, lambda _: events.append("shutdown")) ) monkeypatch.setattr( app.updater, "shutdown", call_after(app.updater.shutdown, lambda _: events.append("updater.shutdown")), ) monkeypatch.setattr(app.bot, "delete_webhook", return_true) thread = Thread(target=thread_target) thread.start() app.run_polling(drop_pending_updates=True, close_loop=False) thread.join() assert events == [ "updater.shutdown", "shutdown", "post_shutdown", ], "Wrong order of events detected!" @pytest.mark.skipif( platform.system() == "Windows", reason="Can't send signals without stopping whole process on windows", ) def test_run_polling_post_stop(self, one_time_bot, monkeypatch): events = [] def thread_target(): waited = 0 while not app.running: time.sleep(0.05) waited += 0.05 if waited > 5: pytest.fail("App apparently won't start") os.kill(os.getpid(), signal.SIGINT) async def post_stop(app: Application) -> None: events.append("post_stop") app = ( Application.builder() .application_class(PytestApplication) .updater(PytestUpdater(one_time_bot, asyncio.Queue())) .post_stop(post_stop) .build() ) app.bot._unfreeze() monkeypatch.setattr(app.bot, "get_updates", empty_get_updates) monkeypatch.setattr(app, "stop", call_after(app.stop, lambda _: events.append("stop"))) monkeypatch.setattr( app.updater, "stop", call_after(app.updater.stop, lambda _: events.append("updater.stop")), ) monkeypatch.setattr( app.updater, "shutdown", call_after(app.updater.shutdown, lambda _: events.append("updater.shutdown")), ) monkeypatch.setattr(app.bot, "delete_webhook", return_true) thread = Thread(target=thread_target) thread.start() app.run_polling(drop_pending_updates=True, close_loop=False) thread.join() assert events == [ "updater.stop", "stop", "post_stop", "updater.shutdown", ], "Wrong order of events detected!" @pytest.mark.skipif( platform.system() == "Windows", reason="Can't send signals without stopping whole process on windows", ) def test_run_polling_parameters_passing(self, app, monkeypatch): # First check that the default values match and that we have all arguments there updater_signature = inspect.signature(app.updater.start_polling) app_signature = inspect.signature(app.run_polling) for name, param in updater_signature.parameters.items(): if name == "error_callback": assert name not in app_signature.parameters continue assert name in app_signature.parameters assert param.kind == app_signature.parameters[name].kind assert param.default == app_signature.parameters[name].default # Check that we pass them correctly async def start_polling(_, **kwargs): self.received = kwargs return True async def stop(_, **kwargs): return True def thread_target(): waited = 0 while not app.running: time.sleep(0.05) waited += 0.05 if waited > 5: pytest.fail("App apparently won't start") time.sleep(0.1) os.kill(os.getpid(), signal.SIGINT) monkeypatch.setattr(Updater, "start_polling", start_polling) monkeypatch.setattr(Updater, "stop", stop) thread = Thread(target=thread_target) thread.start() app.run_polling(close_loop=False) thread.join() assert set(self.received.keys()) == set(updater_signature.parameters.keys()) for name, param in updater_signature.parameters.items(): if name == "error_callback": assert self.received[name] is not None else: assert self.received[name] == param.default expected = { name: name for name in updater_signature.parameters if name != "error_callback" } thread = Thread(target=thread_target) thread.start() app.run_polling(close_loop=False, **expected) thread.join() assert set(self.received.keys()) == set(updater_signature.parameters.keys()) assert self.received.pop("error_callback", None) assert self.received == expected @pytest.mark.skipif( platform.system() == "Windows", reason="Can't send signals without stopping whole process on windows", ) def test_run_webhook_basic(self, app, monkeypatch, caplog): assertions = {} def thread_target(): waited = 0 while not app.running: time.sleep(0.05) waited += 0.05 if waited > 5: pytest.fail("App apparently won't start") # Check that everything's running assertions["app_running"] = app.running assertions["updater_running"] = app.updater.running assertions["job_queue_running"] = app.job_queue.scheduler.running # Check that we're getting updates loop = asyncio.new_event_loop() loop.run_until_complete( send_webhook_message(ip, port, self.message_update.to_json(), "TOKEN") ) loop.close() time.sleep(0.05) assertions["getting_updates"] = self.count == 42 os.kill(os.getpid(), signal.SIGINT) time.sleep(0.1) # # Assert that everything has stopped running assertions["app_not_running"] = not app.running assertions["updater_not_running"] = not app.updater.running assertions["job_queue_not_running"] = not app.job_queue.scheduler.running app.add_handler(TypeHandler(object, self.callback_set_count(42))) thread = Thread(target=thread_target) thread.start() ip = "127.0.0.1" port = randrange(1024, 49152) with caplog.at_level(logging.DEBUG): app.run_webhook( ip_address=ip, port=port, url_path="TOKEN", drop_pending_updates=True, close_loop=False, ) thread.join() assert len(assertions) == 7 for key, value in assertions.items(): assert value, f"assertion '{key}' failed!" found_log = False for record in caplog.records: if "received stop signal" in record.getMessage() and record.levelno == logging.DEBUG: found_log = True assert found_log @pytest.mark.skipif( platform.system() == "Windows", reason="Can't send signals without stopping whole process on windows", ) def test_run_webhook_post_init(self, one_time_bot, monkeypatch): events = [] def thread_target(): waited = 0 while not app.running: time.sleep(0.05) waited += 0.05 if waited > 5: pytest.fail("App apparently won't start") os.kill(os.getpid(), signal.SIGINT) async def post_init(app: Application) -> None: events.append("post_init") app = ( Application.builder() .post_init(post_init) .application_class(PytestApplication) .updater(PytestUpdater(one_time_bot, asyncio.Queue())) .build() ) app.bot._unfreeze() monkeypatch.setattr( app, "initialize", call_after(app.initialize, lambda _: events.append("init")) ) monkeypatch.setattr( app.updater, "start_webhook", call_after(app.updater.start_webhook, lambda _: events.append("start_webhook")), ) monkeypatch.setattr(app.bot, "delete_webhook", return_true) monkeypatch.setattr(app.bot, "set_webhook", return_true) thread = Thread(target=thread_target) thread.start() ip = "127.0.0.1" port = randrange(1024, 49152) app.run_webhook( ip_address=ip, port=port, url_path="TOKEN", drop_pending_updates=True, close_loop=False, ) thread.join() assert events == ["init", "post_init", "start_webhook"], "Wrong order of events detected!" @pytest.mark.skipif( platform.system() == "Windows", reason="Can't send signals without stopping whole process on windows", ) def test_run_webhook_post_shutdown(self, one_time_bot, monkeypatch): events = [] def thread_target(): waited = 0 while not app.running: time.sleep(0.05) waited += 0.05 if waited > 5: pytest.fail("App apparently won't start") os.kill(os.getpid(), signal.SIGINT) async def post_shutdown(app: Application) -> None: events.append("post_shutdown") app = ( Application.builder() .application_class(PytestApplication) .updater(PytestUpdater(one_time_bot, asyncio.Queue())) .post_shutdown(post_shutdown) .build() ) app.bot._unfreeze() monkeypatch.setattr( app, "shutdown", call_after(app.shutdown, lambda _: events.append("shutdown")) ) monkeypatch.setattr( app.updater, "shutdown", call_after(app.updater.shutdown, lambda _: events.append("updater.shutdown")), ) monkeypatch.setattr(app.bot, "delete_webhook", return_true) monkeypatch.setattr(app.bot, "set_webhook", return_true) thread = Thread(target=thread_target) thread.start() ip = "127.0.0.1" port = randrange(1024, 49152) app.run_webhook( ip_address=ip, port=port, url_path="TOKEN", drop_pending_updates=True, close_loop=False, ) thread.join() assert events == [ "updater.shutdown", "shutdown", "post_shutdown", ], "Wrong order of events detected!" @pytest.mark.skipif( platform.system() == "Windows", reason="Can't send signals without stopping whole process on windows", ) def test_run_webhook_post_stop(self, one_time_bot, monkeypatch): events = [] def thread_target(): waited = 0 while not app.running: time.sleep(0.05) waited += 0.05 if waited > 5: pytest.fail("App apparently won't start") os.kill(os.getpid(), signal.SIGINT) async def post_stop(app: Application) -> None: events.append("post_stop") app = ( Application.builder() .application_class(PytestApplication) .updater(PytestUpdater(one_time_bot, asyncio.Queue())) .post_stop(post_stop) .build() ) app.bot._unfreeze() monkeypatch.setattr(app, "stop", call_after(app.stop, lambda _: events.append("stop"))) monkeypatch.setattr( app.updater, "stop", call_after(app.updater.stop, lambda _: events.append("updater.stop")), ) monkeypatch.setattr( app.updater, "shutdown", call_after(app.updater.shutdown, lambda _: events.append("updater.shutdown")), ) monkeypatch.setattr(app.bot, "delete_webhook", return_true) monkeypatch.setattr(app.bot, "set_webhook", return_true) thread = Thread(target=thread_target) thread.start() ip = "127.0.0.1" port = randrange(1024, 49152) app.run_webhook( ip_address=ip, port=port, url_path="TOKEN", drop_pending_updates=True, close_loop=False, ) thread.join() assert events == [ "updater.stop", "stop", "post_stop", "updater.shutdown", ], "Wrong order of events detected!" @pytest.mark.skipif( platform.system() == "Windows", reason="Can't send signals without stopping whole process on windows", ) def test_run_webhook_parameters_passing(self, one_time_bot, monkeypatch): # Check that we pass them correctly async def start_webhook(_, **kwargs): self.received = kwargs return True async def stop(_, **kwargs): return True # First check that the default values match and that we have all arguments there updater_signature = inspect.signature(Updater.start_webhook) monkeypatch.setattr(Updater, "start_webhook", start_webhook) monkeypatch.setattr(Updater, "stop", stop) app = ApplicationBuilder().bot(one_time_bot).build() app_signature = inspect.signature(app.run_webhook) for name, param in updater_signature.parameters.items(): if name == "self": continue assert name in app_signature.parameters assert param.kind == app_signature.parameters[name].kind assert param.default == app_signature.parameters[name].default def thread_target(): waited = 0 while not app.running: time.sleep(0.05) waited += 0.05 if waited > 5: pytest.fail("App apparently won't start") time.sleep(0.1) os.kill(os.getpid(), signal.SIGINT) thread = Thread(target=thread_target) thread.start() app.run_webhook(close_loop=False) thread.join() assert set(self.received.keys()) == set(updater_signature.parameters.keys()) - {"self"} for name, param in updater_signature.parameters.items(): if name == "self": continue assert self.received[name] == param.default expected = {name: name for name in updater_signature.parameters if name != "self"} thread = Thread(target=thread_target) thread.start() app.run_webhook(close_loop=False, **expected) thread.join() assert set(self.received.keys()) == set(expected.keys()) assert self.received == expected @pytest.mark.parametrize("exception", [SystemExit, KeyboardInterrupt]) def test_raise_system_exit_keyboard_interrupt_post_init( self, one_time_bot, monkeypatch, exception ): async def post_init(application): raise exception called_callbacks = set() async def callback(*args, **kwargs): called_callbacks.add(kwargs["name"]) for cls, method, entry in [ (Application, "initialize", "app_initialize"), (Application, "start", "app_start"), (Application, "stop", "app_stop"), (Application, "shutdown", "app_shutdown"), (Updater, "initialize", "updater_initialize"), (Updater, "shutdown", "updater_shutdown"), (Updater, "stop", "updater_stop"), (Updater, "start_polling", "updater_start_polling"), ]: def after(_, name): called_callbacks.add(name) monkeypatch.setattr( cls, method, call_after(getattr(cls, method), functools.partial(after, name=entry)), ) app = ( ApplicationBuilder() .bot(one_time_bot) .post_init(post_init) .post_stop(functools.partial(callback, name="post_stop")) .post_shutdown(functools.partial(callback, name="post_shutdown")) .build() ) app.run_polling(close_loop=False) # This checks two things: # 1. start/stop are *not* called! # 2. we do have a graceful shutdown assert called_callbacks == { "app_initialize", "updater_initialize", "app_shutdown", "post_shutdown", "updater_shutdown", } @pytest.mark.parametrize("exception", [SystemExit("PTBTest"), KeyboardInterrupt("PTBTest")]) @pytest.mark.parametrize("kind", ["handler", "error_handler", "job"]) # @pytest.mark.parametrize("block", [True, False]) # Testing with block=False would be nice but that doesn't work well with pytest for some reason # in any case, block=False is the simpler behavior since it is roughly similar to what happens # when you hit CTRL+C in the commandline. def test_raise_system_exit_keyboard_jobs_handlers( self, one_time_bot, monkeypatch, exception, kind, caplog ): async def queue_and_raise(application): await application.update_queue.put("will_not_be_processed") raise exception async def handler_callback(update, context): if kind == "handler": await queue_and_raise(context.application) elif kind == "error_handler": raise TelegramError("Triggering error callback") async def error_callback(update, context): await queue_and_raise(context.application) async def job_callback(context): await queue_and_raise(context.application) async def enqueue_update(): await asyncio.sleep(0.5) await app.update_queue.put(1) async def post_init(application): if kind == "job": application.job_queue.run_once(when=0.5, callback=job_callback) else: app.create_task(enqueue_update()) async def update_logger_callback(update, context): context.bot_data.setdefault("processed_updates", set()).add(update) called_callbacks = set() async def callback(*args, **kwargs): called_callbacks.add(kwargs["name"]) for cls, method, entry in [ (Application, "initialize", "app_initialize"), (Application, "start", "app_start"), (Application, "stop", "app_stop"), (Application, "shutdown", "app_shutdown"), (Updater, "initialize", "updater_initialize"), (Updater, "shutdown", "updater_shutdown"), (Updater, "stop", "updater_stop"), (Updater, "start_polling", "updater_start_polling"), ]: def after(_, name): called_callbacks.add(name) monkeypatch.setattr( cls, method, call_after(getattr(cls, method), functools.partial(after, name=entry)), ) app = ( ApplicationBuilder() .bot(one_time_bot) .post_init(post_init) .post_stop(functools.partial(callback, name="post_stop")) .post_shutdown(functools.partial(callback, name="post_shutdown")) .build() ) monkeypatch.setattr(app.bot, "get_updates", empty_get_updates) monkeypatch.setattr(app.bot, "delete_webhook", return_true) app.add_handler(TypeHandler(object, update_logger_callback), group=-10) app.add_handler(TypeHandler(object, handler_callback)) app.add_error_handler(error_callback) with caplog.at_level(logging.DEBUG): app.run_polling(close_loop=False) # This checks that we have a clean shutdown even when the user raises SystemExit # or KeyboardInterrupt in a handler/error handler/job callback assert called_callbacks == { "app_initialize", "app_shutdown", "app_start", "app_stop", "post_shutdown", "post_stop", "updater_initialize", "updater_shutdown", "updater_start_polling", "updater_stop", } # These next checks make sure that the update queue is properly cleaned even if there are # still pending updates in the queue # Unfortunately this is apparently extremely hard to get right with jobs, so we're # skipping that case for the sake of simplicity if kind == "job": return found = False for record in caplog.records: if record.getMessage() != "Dropping pending update: will_not_be_processed": continue assert record.name == "telegram.ext.Application" assert record.levelno == logging.DEBUG found = True assert found, "`Dropping pending updates` message not found in logs!" assert "will_not_be_processed" not in app.bot_data.get("processed_updates", set()) def test_run_without_updater(self, one_time_bot): app = ApplicationBuilder().bot(one_time_bot).updater(None).build() with pytest.raises(RuntimeError, match="only available if the application has an Updater"): app.run_webhook() with pytest.raises(RuntimeError, match="only available if the application has an Updater"): app.run_polling() @pytest.mark.parametrize("method", ["start", "initialize"]) @pytest.mark.filterwarnings("ignore::telegram.warnings.PTBUserWarning") def test_run_error_in_application(self, one_time_bot, monkeypatch, method): shutdowns = [] async def raise_method(*args, **kwargs): raise RuntimeError("Test Exception") def after_shutdown(name): def _after_shutdown(*args, **kwargs): shutdowns.append(name) return _after_shutdown monkeypatch.setattr(Application, method, raise_method) monkeypatch.setattr( Application, "shutdown", call_after(Application.shutdown, after_shutdown("application")), ) monkeypatch.setattr( Updater, "shutdown", call_after(Updater.shutdown, after_shutdown("updater")) ) app = ApplicationBuilder().bot(one_time_bot).build() monkeypatch.setattr(app.bot, "delete_webhook", return_true) monkeypatch.setattr(app.bot, "get_updates", empty_get_updates) with pytest.raises(RuntimeError, match="Test Exception"): app.run_polling(close_loop=False) assert not app.running assert not app.updater.running if method == "initialize": # If App.initialize fails, then App.shutdown pretty much does nothing, especially # doesn't call Updater.shutdown. assert set(shutdowns) == {"application"} else: assert set(shutdowns) == {"application", "updater"} @pytest.mark.parametrize("method", ["start_polling", "start_webhook"]) @pytest.mark.filterwarnings("ignore::telegram.warnings.PTBUserWarning") def test_run_error_in_updater(self, one_time_bot, monkeypatch, method): shutdowns = [] async def raise_method(*args, **kwargs): raise RuntimeError("Test Exception") def after_shutdown(name): def _after_shutdown(*args, **kwargs): shutdowns.append(name) return _after_shutdown monkeypatch.setattr(Updater, method, raise_method) monkeypatch.setattr( Application, "shutdown", call_after(Application.shutdown, after_shutdown("application")), ) monkeypatch.setattr( Updater, "shutdown", call_after(Updater.shutdown, after_shutdown("updater")) ) app = ApplicationBuilder().bot(one_time_bot).build() with pytest.raises(RuntimeError, match="Test Exception"): # noqa: PT012 if "polling" in method: app.run_polling(close_loop=False) else: app.run_webhook(close_loop=False) assert not app.running assert not app.updater.running assert set(shutdowns) == {"application", "updater"} @pytest.mark.skipif( platform.system() != "Windows", reason="Only really relevant on windows", ) @pytest.mark.parametrize("method", ["start_polling", "start_webhook"]) def test_run_stop_signal_warning_windows(self, one_time_bot, method, recwarn, monkeypatch): async def raise_method(*args, **kwargs): raise RuntimeError("Prevent Actually Running") monkeypatch.setattr(Application, "initialize", raise_method) app = ApplicationBuilder().bot(one_time_bot).build() with pytest.raises(RuntimeError, match="Prevent Actually Running"): # noqa: PT012 if "polling" in method: app.run_polling(close_loop=False, stop_signals=(signal.SIGINT,)) else: app.run_webhook(close_loop=False, stop_signals=(signal.SIGTERM,)) assert len(recwarn) >= 1 found = False for record in recwarn: print(record) if str(record.message).startswith("Could not add signal handlers for the stop"): assert record.category is PTBUserWarning assert record.filename == __file__, "stacklevel is incorrect!" found = True assert found recwarn.clear() with pytest.raises(RuntimeError, match="Prevent Actually Running"): # noqa: PT012 if "polling" in method: app.run_polling(close_loop=False, stop_signals=None) else: app.run_webhook(close_loop=False, stop_signals=None) for record in recwarn: assert not str(record.message).startswith("Could not add signal handlers for the stop") @pytest.mark.flaky(3, 1) # loop.call_later will error the test when a flood error is received def test_signal_handlers(self, app, monkeypatch): # this test should make sure that signal handlers are set by default on Linux + Mac, # and not on Windows. received_signals = [] def signal_handler_test(*args, **kwargs): # args[0] is the signal, [1] the callback received_signals.append(args[0]) loop = asyncio.get_event_loop() monkeypatch.setattr(loop, "add_signal_handler", signal_handler_test) monkeypatch.setattr(app.bot, "get_updates", empty_get_updates) def abort_app(): raise SystemExit loop.call_later(0.6, abort_app) app.run_polling(close_loop=False) if platform.system() == "Windows": assert received_signals == [] else: assert received_signals == [signal.SIGINT, signal.SIGTERM, signal.SIGABRT] received_signals.clear() loop.call_later(0.8, abort_app) app.run_webhook(port=49152, webhook_url="example.com", close_loop=False) if platform.system() == "Windows": assert received_signals == [] else: assert received_signals == [signal.SIGINT, signal.SIGTERM, signal.SIGABRT] def test_stop_running_not_running(self, app, caplog): with caplog.at_level(logging.DEBUG): app.stop_running() assert len(caplog.records) == 1 assert caplog.records[-1].name == "telegram.ext.Application" assert caplog.records[-1].getMessage().endswith("`stop_running()` likely has no effect.") def test_stop_running_post_init(self, app, monkeypatch, caplog, one_time_bot): async def post_init(app): app.stop_running() called_callbacks = [] async def callback(*args, **kwargs): called_callbacks.append(kwargs["name"]) monkeypatch.setattr(Application, "start", functools.partial(callback, name="start")) monkeypatch.setattr( Updater, "start_polling", functools.partialmethod(callback, name="start_polling") ) app = ( ApplicationBuilder() .bot(one_time_bot) .post_init(post_init) .post_stop(functools.partial(callback, name="post_stop")) .post_shutdown(functools.partial(callback, name="post_shutdown")) .build() ) with caplog.at_level(logging.INFO): app.run_polling(close_loop=False) # The important part here is that start(_polling) are *not* called! # post_stop must not be called either, since we never called stop() assert called_callbacks == ["post_shutdown"] assert len(caplog.records) == 1 assert caplog.records[-1].name == "telegram.ext.Application" assert ( "Application received stop signal via `stop_running`" in caplog.records[-1].getMessage() ) @pytest.mark.parametrize("method", ["polling", "webhook"]) def test_stop_running(self, one_time_bot, monkeypatch, method): # asyncio.Event() seems to be hard to use across different threads (awaiting in main # thread, setting in another thread), so we use threading.Event() instead. # This requires the use of run_in_executor, but that's fine. put_update_event = threading.Event() callback_done_event = threading.Event() called_stop_running = threading.Event() assertions = {} async def post_init(app): # Simply calling app.update_queue.put_nowait(method) in the thread_target doesn't work # for some reason (probably threading magic), so we use an event from the thread_target # to put the update into the queue in the main thread. async def task(app): await asyncio.get_running_loop().run_in_executor(None, put_update_event.wait) await app.update_queue.put(method) app.create_task(task(app)) app = ( ApplicationBuilder() .application_class(PytestApplication) .updater(PytestUpdater(one_time_bot, asyncio.Queue())) .post_init(post_init) .build() ) monkeypatch.setattr(app.bot, "get_updates", empty_get_updates) events = [] monkeypatch.setattr( app.updater, "stop", call_after(app.updater.stop, lambda _: events.append("updater.stop")), ) monkeypatch.setattr( app, "stop", call_after(app.stop, lambda _: events.append("app.stop")), ) monkeypatch.setattr( app, "shutdown", call_after(app.shutdown, lambda _: events.append("app.shutdown")), ) monkeypatch.setattr(app.bot, "set_webhook", return_true) monkeypatch.setattr(app.bot, "delete_webhook", return_true) def thread_target(): waited = 0 while not app.running: time.sleep(0.05) waited += 0.05 if waited > 5: pytest.fail("App apparently won't start") time.sleep(0.1) assertions["called_stop_running_not_set"] = not called_stop_running.is_set() put_update_event.set() time.sleep(0.1) assertions["called_stop_running_set"] = called_stop_running.is_set() # App should have entered `stop` now but not finished it yet because the callback # is still running assertions["updater.stop_event"] = events == ["updater.stop"] assertions["app.running_False"] = not app.running callback_done_event.set() time.sleep(0.1) # Now that the update is fully handled, we expect the full shutdown assertions["events"] = events == ["updater.stop", "app.stop", "app.shutdown"] async def callback(update, context): context.application.stop_running() called_stop_running.set() await asyncio.get_running_loop().run_in_executor(None, callback_done_event.wait) app.add_handler(TypeHandler(object, callback)) thread = Thread(target=thread_target) thread.start() if method == "polling": app.run_polling(close_loop=False, drop_pending_updates=True) else: ip = "127.0.0.1" port = randrange(1024, 49152) app.run_webhook( ip_address=ip, port=port, url_path="TOKEN", drop_pending_updates=False, close_loop=False, ) thread.join() assert len(assertions) == 5 for key, value in assertions.items(): assert value, f"assertion '{key}' failed!" async def test_process_update_exception_in_building_context(self, monkeypatch, caplog, app): # Makes sure that exceptions in building the context don't stop the application exception = ValueError("TestException") original_from_update = CallbackContext.from_update def raise_exception(update, application): if update == 1: raise exception return original_from_update(update, application) monkeypatch.setattr(CallbackContext, "from_update", raise_exception) received_updates = set() async def callback(update, context): received_updates.add(update) app.add_handler(TypeHandler(int, callback)) async with app: with caplog.at_level(logging.CRITICAL): await app.process_update(1) assert received_updates == set() assert len(caplog.records) == 1 record = caplog.records[0] assert record.name == "telegram.ext.Application" assert record.getMessage().startswith( "Error while building CallbackContext for update 1" ) assert record.levelno == logging.CRITICAL # Let's also check that no critical log is produced when the exception is not raised caplog.clear() with caplog.at_level(logging.CRITICAL): await app.process_update(2) assert received_updates == {2} assert len(caplog.records) == 0 async def test_process_error_exception_in_building_context(self, monkeypatch, caplog, app): # Makes sure that exceptions in building the context don't stop the application exception = ValueError("TestException") original_from_error = CallbackContext.from_error def raise_exception(update, error, application, *args, **kwargs): if error == 1: raise exception return original_from_error(update, error, application, *args, **kwargs) monkeypatch.setattr(CallbackContext, "from_error", raise_exception) received_errors = set() async def callback(update, context): received_errors.add(context.error) app.add_error_handler(callback) async with app: with caplog.at_level(logging.CRITICAL): await app.process_error(update=None, error=1) assert received_errors == set() assert len(caplog.records) == 1 record = caplog.records[0] assert record.name == "telegram.ext.Application" assert record.getMessage().startswith( "Error while building CallbackContext for exception 1" ) assert record.levelno == logging.CRITICAL # Let's also check that no critical log is produced when the exception is not raised caplog.clear() with caplog.at_level(logging.CRITICAL): await app.process_error(update=None, error=2) assert received_errors == {2} assert len(caplog.records) == 0