#!/usr/bin/env python
#
# A library that provides a Python interface to the Telegram Bot API
# Copyright (C) 2015-2025
# Leandro Toledo de Souza
#
# This program is free software: you can redistribute it and/or modify
# it under the terms of the GNU Lesser Public License as published by
# the Free Software Foundation, either version 3 of the License, or
# (at your option) any later version.
#
# This program is distributed in the hope that it will be useful,
# but WITHOUT ANY WARRANTY; without even the implied warranty of
# MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE.  See the
# GNU Lesser Public License for more details.
#
# You should have received a copy of the GNU Lesser Public License
# along with this program.  If not, see [http://www.gnu.org/licenses/].
"""
We mostly test directly on AIORateLimiter here, because BaseRateLimiter doesn't contain
anything notable.
"""
import asyncio
import datetime as dtm
import json
import platform
import time
from collections import Counter
from http import HTTPStatus

import pytest

from telegram import BotCommand, Chat, Message, User
from telegram.constants import ParseMode
from telegram.error import RetryAfter
from telegram.ext import AIORateLimiter, BaseRateLimiter, Defaults, ExtBot
from telegram.request import BaseRequest, RequestData
from tests.auxil.envvars import GITHUB_ACTIONS, TEST_WITH_OPT_DEPS


@pytest.mark.skipif(
    TEST_WITH_OPT_DEPS, reason="Only relevant if the optional dependency is not installed"
)
class TestNoRateLimiter:
    """Behavior of AIORateLimiter when the optional rate-limiter dependency is missing."""

    def test_init(self):
        """Instantiating AIORateLimiter must point the user to the missing extra."""
        with pytest.raises(RuntimeError, match=r"python-telegram-bot\[rate-limiter\]"):
            AIORateLimiter()


class TestBaseRateLimiter:
    """Tests for how ExtBot hands requests to a (custom) BaseRateLimiter."""

    # Filled by the helper classes inside `test_argument_passing`:
    # rl_received      -> list of (endpoint, data, rate_limit_args) seen by the rate limiter
    # request_received -> list of (args, kwargs) seen by the request backend
    rl_received = None
    request_received = None

    @pytest.fixture(autouse=True)
    def _reset(self):
        # Class-level state would otherwise leak between runs of the same test (e.g. reruns)
        # and break the exact length checks in `test_argument_passing`.
        TestBaseRateLimiter.rl_received = None
        TestBaseRateLimiter.request_received = None

    async def test_no_rate_limiter(self, bot):
        """Passing `rate_limit_args` without a configured rate limiter is an error."""
        with pytest.raises(ValueError, match="if a `ExtBot.rate_limiter` is set"):
            await bot.send_message(chat_id=42, text="test", rate_limit_args="something")

    async def test_argument_passing(self, bot_info, monkeypatch, bot):
        """The rate limiter receives endpoint, data and `rate_limit_args`, and the request
        finally sent is identical to the one sent by a bot without a rate limiter."""

        class TestRateLimiter(BaseRateLimiter):
            """Records every call to `process_request` and forwards it unchanged."""

            async def initialize(self) -> None:
                pass

            async def shutdown(self) -> None:
                pass

            async def process_request(
                self,
                callback,
                args,
                kwargs,
                endpoint,
                data,
                rate_limit_args,
            ):
                if TestBaseRateLimiter.rl_received is None:
                    TestBaseRateLimiter.rl_received = []
                TestBaseRateLimiter.rl_received.append((endpoint, data, rate_limit_args))
                return await callback(*args, **kwargs)

        class TestRequest(BaseRequest):
            """Records every outgoing request and answers with the bot's own user."""

            async def initialize(self) -> None:
                pass

            async def shutdown(self) -> None:
                pass

            async def do_request(self, *args, **kwargs):
                if TestBaseRateLimiter.request_received is None:
                    TestBaseRateLimiter.request_received = []
                TestBaseRateLimiter.request_received.append((args, kwargs))
                # return bot.bot.to_dict() for the `get_me` call in `Bot.initialize`
                return 200, json.dumps({"ok": True, "result": bot.bot.to_dict()}).encode()

        defaults = Defaults(parse_mode=ParseMode.HTML)
        test_request = TestRequest()
        standard_bot = ExtBot(token=bot.token, defaults=defaults, request=test_request)
        rl_bot = ExtBot(
            token=bot.token,
            defaults=defaults,
            request=test_request,
            rate_limiter=TestRateLimiter(),
        )

        async with standard_bot:
            await standard_bot.set_my_commands(
                commands=[BotCommand("test", "test")],
                language_code="en",
                api_kwargs={"api": "kwargs"},
            )
        async with rl_bot:
            await rl_bot.set_my_commands(
                commands=[BotCommand("test", "test")],
                language_code="en",
                rate_limit_args=(43, "test-1"),
                api_kwargs={"api": "kwargs"},
            )

        assert len(self.rl_received) == 2
        assert self.rl_received[0] == ("getMe", {}, None)
        assert self.rl_received[1] == (
            "setMyCommands",
            {"commands": [BotCommand("test", "test")], "language_code": "en", "api": "kwargs"},
            (43, "test-1"),
        )
        assert len(self.request_received) == 4
        # self.request_received[i] = i-th received request
        # self.request_received[i][0] = i-th received request's args
        # self.request_received[i][1] = i-th received request's kwargs
        assert self.request_received[0][1]["url"].endswith("getMe")
        assert self.request_received[2][1]["url"].endswith("getMe")
        assert self.request_received[1][0] == self.request_received[3][0]
        assert self.request_received[1][1].keys() == self.request_received[3][1].keys()
        for key, value in self.request_received[1][1].items():
            if isinstance(value, RequestData):
                assert value.parameters == self.request_received[3][1][key].parameters
                assert value.parameters["api"] == "kwargs"
            else:
                assert value == self.request_received[3][1][key]


@pytest.mark.skipif(
    not TEST_WITH_OPT_DEPS, reason="Only relevant if the optional dependency is installed"
)
@pytest.mark.skipif(
    GITHUB_ACTIONS and platform.system() == "Darwin",
    reason="The timings are apparently rather inaccurate on MacOS.",
)
@pytest.mark.flaky(10, 1)  # Timings aren't quite perfect
class TestAIORateLimiter:
    """Timing-based tests for AIORateLimiter."""

    # Counters and timestamps recorded by `CountRequest`; reset before each test by `_reset`.
    # `apb_*` track requests with `allow_paid_broadcast=True`, the others all remaining ones.
    count = 0
    apb_count = 0
    call_times = []
    apb_call_times = []

    class CountRequest(BaseRequest):
        """Fake request backend that counts and timestamps each request.

        If `retry_after` is truthy, every request raises `RetryAfter(retry_after=1)` after
        being counted. Otherwise `getMe` and `sendMessage` get canned successful answers.
        """

        def __init__(self, retry_after=None):
            self.retry_after = retry_after

        async def initialize(self) -> None:
            pass

        async def shutdown(self) -> None:
            pass

        async def do_request(self, *args, **kwargs):
            request_data = kwargs.get("request_data")
            allow_paid_broadcast = request_data.parameters.get("allow_paid_broadcast", False)
            if allow_paid_broadcast:
                TestAIORateLimiter.apb_count += 1
                TestAIORateLimiter.apb_call_times.append(time.time())
            else:
                TestAIORateLimiter.count += 1
                TestAIORateLimiter.call_times.append(time.time())

            if self.retry_after:
                raise RetryAfter(retry_after=1)

            url = kwargs.get("url").lower()
            if url.endswith("getme"):
                return (
                    HTTPStatus.OK,
                    json.dumps(
                        {"ok": True, "result": User(id=1, first_name="bot", is_bot=True).to_dict()}
                    ).encode(),
                )
            if url.endswith("sendmessage"):
                return (
                    HTTPStatus.OK,
                    json.dumps(
                        {
                            "ok": True,
                            "result": Message(
                                message_id=1, date=dtm.datetime.now(), chat=Chat(1, "chat")
                            ).to_dict(),
                        }
                    ).encode(),
                )
            # Other endpoints are not used by these tests.
            return None

    @pytest.fixture(autouse=True)
    def _reset(self):
        TestAIORateLimiter.count = 0
        TestAIORateLimiter.call_times = []
        TestAIORateLimiter.apb_count = 0
        TestAIORateLimiter.apb_call_times = []

    @pytest.mark.parametrize("max_retries", [0, 1, 4])
    async def test_max_retries(self, bot, max_retries):
        """A request that keeps raising RetryAfter is attempted `max_retries + 1` times, with
        consecutive attempts roughly 1.1 seconds apart."""
        bot = ExtBot(
            token=bot.token,
            request=self.CountRequest(retry_after=1),
            rate_limiter=AIORateLimiter(
                max_retries=max_retries, overall_max_rate=0, group_max_rate=0
            ),
        )
        with pytest.raises(RetryAfter):
            await bot.get_me()

        # Check that we retried the request the correct number of times
        assert TestAIORateLimiter.count == max_retries + 1

        # Check that the retries were delayed correctly
        times = TestAIORateLimiter.call_times
        if len(times) <= 1:
            return
        delays = [later - earlier for earlier, later in zip(times[:-1], times[1:])]
        assert delays == pytest.approx([1.1 for _ in range(max_retries)], rel=0.05)

    async def test_delay_all_pending_on_retry(self, bot):
        """A RetryAfter blocks *all* pending requests, not only the one that triggered it."""
        bot = ExtBot(
            token=bot.token,
            request=self.CountRequest(retry_after=1),
            rate_limiter=AIORateLimiter(max_retries=1, overall_max_rate=0, group_max_rate=0),
        )

        task_1 = asyncio.create_task(bot.get_me())
        await asyncio.sleep(0.1)
        task_2 = asyncio.create_task(bot.get_me())

        assert not task_1.done()
        assert not task_2.done()

        await asyncio.sleep(1.1)
        assert isinstance(task_1.exception(), RetryAfter)
        assert not task_2.done()

        await asyncio.sleep(1.1)
        assert isinstance(task_2.exception(), RetryAfter)

    @pytest.mark.parametrize("group_id", [-1, "-1", "@username"])
    @pytest.mark.parametrize("chat_id", [1, "1"])
    async def test_basic_rate_limiting(self, bot, group_id, chat_id):
        """Requests are throttled by the overall limit and, for groups, by the group limit."""
        # Created before `try` so that the cleanup in `finally` cannot fail with an
        # UnboundLocalError (masking the real error) if building/initializing the bot raises.
        non_group_tasks = {}
        group_tasks = {}
        try:
            rl_bot = ExtBot(
                token=bot.token,
                request=self.CountRequest(retry_after=None),
                rate_limiter=AIORateLimiter(
                    overall_max_rate=1,
                    overall_time_period=1 / 4,
                    group_max_rate=1,
                    group_time_period=1 / 2,
                ),
            )

            async with rl_bot:
                for i in range(4):
                    group_tasks[i] = asyncio.create_task(
                        rl_bot.send_message(chat_id=group_id, text="test")
                    )
                for i in range(8):
                    non_group_tasks[i] = asyncio.create_task(
                        rl_bot.send_message(chat_id=chat_id, text="test")
                    )

                await asyncio.sleep(0.85)
                # We expect 5 requests:
                # 1: `get_me` from `async with rl_bot`
                # 2: `send_message` at time 0.00
                # 3: `send_message` at time 0.25
                # 4: `send_message` at time 0.50
                # 5: `send_message` at time 0.75
                assert TestAIORateLimiter.count == 5
                assert sum(1 for task in non_group_tasks.values() if task.done()) < 8
                assert sum(1 for task in group_tasks.values() if task.done()) < 4

                # 3 seconds after start
                await asyncio.sleep(3.1 - 0.85)
                assert all(task.done() for task in non_group_tasks.values())
                assert all(task.done() for task in group_tasks.values())

        finally:
            # cleanup
            await asyncio.gather(*non_group_tasks.values(), *group_tasks.values())
            TestAIORateLimiter.count = 0
            TestAIORateLimiter.call_times = []

    async def test_rate_limiting_no_chat_id(self, bot):
        """Requests without a `chat_id` (here `get_me`) go through right away, while
        `send_message` calls are throttled by the overall limit."""
        # Created before `try` so that the cleanup in `finally` cannot fail with an
        # UnboundLocalError (masking the real error) if building/initializing the bot raises.
        non_chat_tasks = {}
        chat_tasks = {}
        try:
            rl_bot = ExtBot(
                token=bot.token,
                request=self.CountRequest(retry_after=None),
                rate_limiter=AIORateLimiter(
                    overall_max_rate=1,
                    overall_time_period=1 / 2,
                ),
            )

            async with rl_bot:
                for i in range(4):
                    chat_tasks[i] = asyncio.create_task(
                        rl_bot.send_message(chat_id=-1, text="test")
                    )
                for i in range(8):
                    non_chat_tasks[i] = asyncio.create_task(rl_bot.get_me())

                await asyncio.sleep(0.6)
                # We expect 11 requests:
                # 1: `get_me` from `async with rl_bot`
                # 2: `send_message` at time 0.00
                # 3: `send_message` at time 0.50
                # 4: 8 times `get_me`
                assert TestAIORateLimiter.count == 11
                assert sum(1 for task in non_chat_tasks.values() if task.done()) == 8
                assert sum(1 for task in chat_tasks.values() if task.done()) == 2

                # 1.6 seconds after start
                await asyncio.sleep(1.6 - 0.6)
                assert all(task.done() for task in non_chat_tasks.values())
                assert all(task.done() for task in chat_tasks.values())
        finally:
            # cleanup
            await asyncio.gather(*non_chat_tasks.values(), *chat_tasks.values())
            TestAIORateLimiter.count = 0
            TestAIORateLimiter.call_times = []

    @pytest.mark.parametrize("intermediate", [True, False])
    async def test_group_caching(self, bot, intermediate):
        """Checks the number of cached per-group limiters.

        Right after 513 group requests, the next request leaves between 1 and 513 cached.
        After waiting one group time period, the next request leaves exactly one.
        """
        try:
            max_rate = 1000
            rl_bot = ExtBot(
                token=bot.token,
                request=self.CountRequest(retry_after=None),
                rate_limiter=AIORateLimiter(
                    overall_max_rate=max_rate,
                    overall_time_period=1,
                    group_max_rate=max_rate,
                    group_time_period=1,
                ),
            )

            # Unfortunately, there is no reliable way to test this without checking the internals
            assert len(rl_bot.rate_limiter._group_limiters) == 0
            await asyncio.gather(
                *(rl_bot.send_message(chat_id=-(i + 1), text=f"{i}") for i in range(513))
            )
            if intermediate:
                await rl_bot.send_message(chat_id=-1, text="999")
                assert 1 <= len(rl_bot.rate_limiter._group_limiters) <= 513
            else:
                await asyncio.sleep(1)
                await rl_bot.send_message(chat_id=-1, text="999")
                assert len(rl_bot.rate_limiter._group_limiters) == 1
        finally:
            TestAIORateLimiter.count = 0
            TestAIORateLimiter.call_times = []

    async def test_allow_paid_broadcast(self, bot):
        """Requests with `allow_paid_broadcast=True` are throttled separately.

        A large backlog of them must not block regular requests.
        """
        # Created before `try` so that the cleanup in `finally` cannot fail with an
        # UnboundLocalError (masking the real error) if building/initializing the bot raises.
        apb_tasks = {}
        non_apb_tasks = {}
        try:
            rl_bot = ExtBot(
                token=bot.token,
                request=self.CountRequest(retry_after=None),
                rate_limiter=AIORateLimiter(),
            )

            async with rl_bot:
                for i in range(3000):
                    apb_tasks[i] = asyncio.create_task(
                        rl_bot.send_message(chat_id=-1, text="test", allow_paid_broadcast=True)
                    )

                number = 2
                for i in range(number):
                    non_apb_tasks[i] = asyncio.create_task(
                        rl_bot.send_message(chat_id=-1, text="test")
                    )
                    non_apb_tasks[i + number] = asyncio.create_task(
                        rl_bot.send_message(chat_id=-1, text="test", allow_paid_broadcast=False)
                    )

                await asyncio.sleep(0.1)
                # We expect 5 non-apb requests:
                # 1: `get_me` from `async with rl_bot`
                # 2-5: `send_message`
                assert TestAIORateLimiter.count == 5
                assert sum(1 for task in non_apb_tasks.values() if task.done()) == 4

                # ~2 seconds after start
                # We do the checks once all apb_tasks are done as apparently getting the timings
                # right to check after 1 second is hard
                await asyncio.sleep(2.1 - 0.1)
                assert all(task.done() for task in apb_tasks.values())

                apb_call_times = [
                    ct - TestAIORateLimiter.apb_call_times[0]
                    for ct in TestAIORateLimiter.apb_call_times
                ]
                apb_call_times_dict = Counter(map(int, apb_call_times))

                # We expect at most ~2000 apb requests within the first second.
                # 2000 (>>1000), since we have a floating window logic such that an initial
                # burst is allowed that is hard to measure in the tests
                assert apb_call_times_dict[0] <= 2000
                assert apb_call_times_dict[0] + apb_call_times_dict[1] < 3000
                assert sum(apb_call_times_dict.values()) == 3000

        finally:
            # cleanup
            await asyncio.gather(*apb_tasks.values(), *non_apb_tasks.values())