mirror of
https://github.com/python-telegram-bot/python-telegram-bot.git
synced 2025-01-07 02:59:57 +01:00
c440c255a7
Co-authored-by: Harshil <37377066+harshil21@users.noreply.github.com>
937 lines
31 KiB
Python
#!/usr/bin/env python
|
|
#
|
|
# A library that provides a Python interface to the Telegram Bot API
|
|
# Copyright (C) 2015-2022
|
|
# Leandro Toledo de Souza <devs@python-telegram-bot.org>
|
|
#
|
|
# This program is free software: you can redistribute it and/or modify
|
|
# it under the terms of the GNU Lesser Public License as published by
|
|
# the Free Software Foundation, either version 3 of the License, or
|
|
# (at your option) any later version.
|
|
#
|
|
# This program is distributed in the hope that it will be useful,
|
|
# but WITHOUT ANY WARRANTY; without even the implied warranty of
|
|
# MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the
|
|
# GNU Lesser Public License for more details.
|
|
#
|
|
# You should have received a copy of the GNU Lesser Public License
|
|
# along with this program. If not, see [http://www.gnu.org/licenses/].
|
|
import logging
|
|
from queue import Queue
|
|
from threading import current_thread
|
|
from time import sleep
|
|
|
|
import pytest
|
|
|
|
from telegram import Message, User, Chat, Update, Bot, MessageEntity
|
|
from telegram.ext import (
|
|
MessageHandler,
|
|
Filters,
|
|
Defaults,
|
|
CommandHandler,
|
|
CallbackContext,
|
|
JobQueue,
|
|
BasePersistence,
|
|
ContextTypes,
|
|
)
|
|
from telegram.ext import PersistenceInput
|
|
from telegram.ext.dispatcher import Dispatcher, DispatcherHandlerStop
|
|
from telegram.utils.defaultvalue import DEFAULT_FALSE
|
|
from telegram.error import TelegramError
|
|
from tests.conftest import create_dp
|
|
from collections import defaultdict
|
|
|
|
|
|
@pytest.fixture(scope='function')
def dp2(bot):
    """Supply a second dispatcher for ``bot``, delegating setup/teardown to ``create_dp``.

    NOTE(review): presumably independent of the ``dp`` fixture (see
    ``test_run_async_multiple``) — confirm against ``tests.conftest.create_dp``.
    """
    dispatcher_lifecycle = create_dp(bot)
    yield from dispatcher_lifecycle
|
|
|
|
|
|
class CustomContext(CallbackContext):
    """Trivial ``CallbackContext`` subclass used to check that custom context types are honoured."""
|
|
|
|
|
|
class TestDispatcher:
    """Tests for ``telegram.ext.dispatcher.Dispatcher``.

    The helper callbacks below record what they observe in ``self.received`` and
    ``self.count``; the autouse ``reset`` fixture clears both before every test.
    Tests that put updates on ``dp.update_queue`` sleep briefly before asserting,
    because the running dispatcher consumes that queue in the background.
    """

    message_update = Update(
        1, message=Message(1, None, Chat(1, ''), from_user=User(1, '', False), text='Text')
    )
    received = None
    count = 0

    @pytest.fixture(autouse=True, name='reset')
    def reset_fixture(self):
        self.reset()

    def reset(self):
        """Clear the state recorded by the helper callbacks."""
        self.received = None
        self.count = 0

    def error_handler_context(self, update, context):
        """Error handler: record the error's message."""
        self.received = context.error.message

    def error_handler_raise_error(self, update, context):
        """Error handler that itself fails."""
        raise Exception('Failing bigly')

    def callback_increase_count(self, update, context):
        self.count += 1

    def callback_set_count(self, count):
        """Build a handler callback that sets ``self.count`` to ``count``."""

        def callback(update, context):
            self.count = count

        return callback

    def callback_raise_error(self, update, context):
        """Handler callback that raises a ``TelegramError`` carrying the message text."""
        raise TelegramError(update.message.text)

    def callback_received(self, update, context):
        self.received = update.message

    def callback_context(self, update, context):
        """Record the error message only if ``context`` exposes the expected attribute types."""
        if (
            isinstance(context, CallbackContext)
            and isinstance(context.bot, Bot)
            and isinstance(context.update_queue, Queue)
            and isinstance(context.job_queue, JobQueue)
            and isinstance(context.error, TelegramError)
        ):
            self.received = context.error.message

    def test_slot_behaviour(self, bot, mro_slots):
        dp = Dispatcher(bot=bot, update_queue=None)
        for at in dp.__slots__:
            # Private (double-underscore) slots are name-mangled on the instance.
            at = f"_Dispatcher{at}" if at.startswith('__') and not at.endswith('__') else at
            assert getattr(dp, at, 'err') != 'err', f"got extra slot '{at}'"
        assert len(mro_slots(dp)) == len(set(mro_slots(dp))), "duplicate slot"

    def test_less_than_one_worker_warning(self, dp, recwarn):
        Dispatcher(dp.bot, dp.update_queue, job_queue=dp.job_queue, workers=0)
        assert len(recwarn) == 1
        assert (
            str(recwarn[0].message)
            == 'Asynchronous callbacks can not be processed without at least one worker thread.'
        )
        assert recwarn[0].filename == __file__, "stacklevel is incorrect!"

    def test_one_context_per_update(self, dp):
        """Handlers in different groups share one context object per update, not across updates."""

        def one(update, context):
            if update.message.text == 'test':
                context.my_flag = True

        def two(update, context):
            if update.message.text == 'test':
                if not hasattr(context, 'my_flag'):
                    pytest.fail()
            else:
                if hasattr(context, 'my_flag'):
                    pytest.fail()

        dp.add_handler(MessageHandler(Filters.regex('test'), one), group=1)
        dp.add_handler(MessageHandler(None, two), group=2)
        u = Update(1, Message(1, None, None, None, text='test'))
        dp.process_update(u)
        u.message.text = 'something'
        dp.process_update(u)

    def test_error_handler(self, dp):
        dp.add_error_handler(self.error_handler_context)
        error = TelegramError('Unauthorized.')
        dp.update_queue.put(error)
        sleep(0.1)
        assert self.received == 'Unauthorized.'

        # Remove handler
        dp.remove_error_handler(self.error_handler_context)
        self.reset()

        dp.update_queue.put(error)
        sleep(0.1)
        assert self.received is None

    def test_double_add_error_handler(self, dp, caplog):
        dp.add_error_handler(self.error_handler_context)
        with caplog.at_level(logging.DEBUG):
            dp.add_error_handler(self.error_handler_context)
            assert len(caplog.records) == 1
            assert caplog.records[-1].getMessage().startswith('The callback is already registered')

    def test_construction_with_bad_persistence(self, caplog, bot):
        class my_per:
            def __init__(self):
                self.store_data = PersistenceInput(False, False, False, False)

        with pytest.raises(
            TypeError, match='persistence must be based on telegram.ext.BasePersistence'
        ):
            Dispatcher(bot, None, persistence=my_per())

    def test_error_handler_that_raises_errors(self, dp):
        """
        Make sure that errors raised in error handlers don't break the main loop of the dispatcher
        """
        handler_raise_error = MessageHandler(Filters.all, self.callback_raise_error)
        handler_increase_count = MessageHandler(Filters.all, self.callback_increase_count)
        error = TelegramError('Unauthorized.')

        dp.add_error_handler(self.error_handler_raise_error)

        # From errors caused by handlers
        dp.add_handler(handler_raise_error)
        dp.update_queue.put(self.message_update)
        sleep(0.1)

        # From errors in the update_queue
        dp.remove_handler(handler_raise_error)
        dp.add_handler(handler_increase_count)
        dp.update_queue.put(error)
        dp.update_queue.put(self.message_update)
        sleep(0.1)

        assert self.count == 1

    @pytest.mark.parametrize(['run_async', 'expected_output'], [(True, 5), (False, 0)])
    def test_default_run_async_error_handler(self, dp, monkeypatch, run_async, expected_output):
        def mock_async_err_handler(*args, **kwargs):
            self.count = 5

        # set defaults value to dp.bot
        dp.bot.defaults = Defaults(run_async=run_async)
        try:
            dp.add_handler(MessageHandler(Filters.all, self.callback_raise_error))
            dp.add_error_handler(self.error_handler_context)

            monkeypatch.setattr(dp, 'run_async', mock_async_err_handler)
            dp.process_update(self.message_update)

            assert self.count == expected_output

        finally:
            # reset dp.bot.defaults values
            dp.bot.defaults = None

    @pytest.mark.parametrize(
        ['run_async', 'expected_output'], [(True, 'running async'), (False, None)]
    )
    def test_default_run_async(self, monkeypatch, dp, run_async, expected_output):
        def mock_run_async(*args, **kwargs):
            self.received = 'running async'

        # set defaults value to dp.bot
        dp.bot.defaults = Defaults(run_async=run_async)
        try:
            dp.add_handler(MessageHandler(Filters.all, lambda u, c: None))
            monkeypatch.setattr(dp, 'run_async', mock_run_async)
            dp.process_update(self.message_update)
            assert self.received == expected_output

        finally:
            # reset defaults value
            dp.bot.defaults = None

    def test_run_async_multiple(self, bot, dp, dp2):
        def get_dispatcher_name(q):
            q.put(current_thread().name)

        q1 = Queue()
        q2 = Queue()

        dp.run_async(get_dispatcher_name, q1)
        dp2.run_async(get_dispatcher_name, q2)

        sleep(0.1)

        name1 = q1.get()
        name2 = q2.get()

        assert name1 != name2

    def test_async_raises_dispatcher_handler_stop(self, dp, recwarn):
        def callback(update, context):
            raise DispatcherHandlerStop()

        dp.add_handler(MessageHandler(Filters.all, callback, run_async=True))

        dp.update_queue.put(self.message_update)
        sleep(0.1)
        assert len(recwarn) == 1
        assert str(recwarn[-1].message).startswith(
            'DispatcherHandlerStop is not supported with async functions'
        )

    def test_add_async_handler(self, dp):
        dp.add_handler(
            MessageHandler(
                Filters.all,
                self.callback_received,
                run_async=True,
            )
        )

        dp.update_queue.put(self.message_update)
        sleep(0.1)
        assert self.received == self.message_update.message

    def test_run_async_no_error_handler(self, dp, caplog):
        def func():
            raise RuntimeError('Async Error')

        with caplog.at_level(logging.ERROR):
            dp.run_async(func)
            sleep(0.1)
            assert len(caplog.records) == 1
            assert caplog.records[-1].getMessage().startswith('No error handlers are registered')

    def test_async_handler_async_error_handler_context(self, dp):
        dp.add_handler(MessageHandler(Filters.all, self.callback_raise_error, run_async=True))
        dp.add_error_handler(self.error_handler_context, run_async=True)

        dp.update_queue.put(self.message_update)
        sleep(2)
        assert self.received == self.message_update.message.text

    def test_async_handler_error_handler_that_raises_error(self, dp, caplog):
        handler = MessageHandler(Filters.all, self.callback_raise_error, run_async=True)
        dp.add_handler(handler)
        dp.add_error_handler(self.error_handler_raise_error, run_async=False)

        with caplog.at_level(logging.ERROR):
            dp.update_queue.put(self.message_update)
            sleep(0.1)
            assert len(caplog.records) == 1
            assert (
                caplog.records[-1].getMessage().startswith('An error was raised and an uncaught')
            )

        # Make sure that the main loop still runs
        dp.remove_handler(handler)
        dp.add_handler(MessageHandler(Filters.all, self.callback_increase_count, run_async=True))
        dp.update_queue.put(self.message_update)
        sleep(0.1)
        assert self.count == 1

    def test_async_handler_async_error_handler_that_raises_error(self, dp, caplog):
        handler = MessageHandler(Filters.all, self.callback_raise_error, run_async=True)
        dp.add_handler(handler)
        dp.add_error_handler(self.error_handler_raise_error, run_async=True)

        with caplog.at_level(logging.ERROR):
            dp.update_queue.put(self.message_update)
            sleep(0.1)
            assert len(caplog.records) == 1
            assert (
                caplog.records[-1].getMessage().startswith('An error was raised and an uncaught')
            )

        # Make sure that the main loop still runs
        dp.remove_handler(handler)
        dp.add_handler(MessageHandler(Filters.all, self.callback_increase_count, run_async=True))
        dp.update_queue.put(self.message_update)
        sleep(0.1)
        assert self.count == 1

    def test_error_in_handler(self, dp):
        dp.add_handler(MessageHandler(Filters.all, self.callback_raise_error))
        dp.add_error_handler(self.error_handler_context)

        dp.update_queue.put(self.message_update)
        sleep(0.1)
        assert self.received == self.message_update.message.text

    def test_add_remove_handler(self, dp):
        handler = MessageHandler(Filters.all, self.callback_increase_count)
        dp.add_handler(handler)
        dp.update_queue.put(self.message_update)
        sleep(0.1)
        assert self.count == 1
        dp.remove_handler(handler)
        dp.update_queue.put(self.message_update)
        # Wait here as well: without it, the assertion could run before the dispatcher
        # has looked at the update, so the handler's removal would not be verified.
        sleep(0.1)
        assert self.count == 1

    def test_add_remove_handler_non_default_group(self, dp):
        handler = MessageHandler(Filters.all, self.callback_increase_count)
        dp.add_handler(handler, group=2)
        # The handler lives in group 2, so removing it from the default group fails.
        with pytest.raises(KeyError):
            dp.remove_handler(handler)
        dp.remove_handler(handler, group=2)

    def test_error_start_twice(self, dp):
        assert dp.running
        dp.start()

    def test_handler_order_in_group(self, dp):
        """Only the first matching handler within a group is run."""
        dp.add_handler(MessageHandler(Filters.photo, self.callback_set_count(1)))
        dp.add_handler(MessageHandler(Filters.all, self.callback_set_count(2)))
        dp.add_handler(MessageHandler(Filters.text, self.callback_set_count(3)))
        dp.update_queue.put(self.message_update)
        sleep(0.1)
        assert self.count == 2

    def test_groups(self, dp):
        """One matching handler per group is run."""
        dp.add_handler(MessageHandler(Filters.all, self.callback_increase_count))
        dp.add_handler(MessageHandler(Filters.all, self.callback_increase_count), group=2)
        dp.add_handler(MessageHandler(Filters.all, self.callback_increase_count), group=-1)

        dp.update_queue.put(self.message_update)
        sleep(0.1)
        assert self.count == 3

    def test_add_handler_errors(self, dp):
        handler = 'not a handler'
        with pytest.raises(TypeError, match='handler is not an instance of'):
            dp.add_handler(handler)

        handler = MessageHandler(Filters.photo, self.callback_set_count(1))
        with pytest.raises(TypeError, match='group is not int'):
            dp.add_handler(handler, 'one')

    def test_flow_stop(self, dp, bot):
        passed = []

        def start1(update, context):
            passed.append('start1')
            raise DispatcherHandlerStop

        def start2(update, context):
            passed.append('start2')

        def start3(update, context):
            passed.append('start3')

        update = Update(
            1,
            message=Message(
                1,
                None,
                None,
                None,
                text='/start',
                entities=[
                    MessageEntity(type=MessageEntity.BOT_COMMAND, offset=0, length=len('/start'))
                ],
                bot=bot,
            ),
        )

        # If Stop raised handlers in other groups should not be called.
        dp.add_handler(CommandHandler('start', start1), 1)
        dp.add_handler(CommandHandler('start', start3), 1)
        dp.add_handler(CommandHandler('start', start2), 2)
        dp.process_update(update)
        assert passed == ['start1']

    def test_exception_in_handler(self, dp, bot):
        passed = []
        err = Exception('General exception')

        def start1(u, c):
            passed.append('start1')
            raise err

        def start2(u, c):
            passed.append('start2')

        def start3(u, c):
            passed.append('start3')

        def error(u, c):
            passed.append('error')
            passed.append(c.error)

        update = Update(
            1,
            message=Message(
                1,
                None,
                None,
                None,
                text='/start',
                entities=[
                    MessageEntity(type=MessageEntity.BOT_COMMAND, offset=0, length=len('/start'))
                ],
                bot=bot,
            ),
        )

        # If an unhandled exception was caught, no further handlers from the same group should be
        # called. Also, the error handler should be called and receive the exception
        dp.add_handler(CommandHandler('start', start1), 1)
        dp.add_handler(CommandHandler('start', start2), 1)
        dp.add_handler(CommandHandler('start', start3), 2)
        dp.add_error_handler(error)
        dp.process_update(update)
        assert passed == ['start1', 'error', err, 'start3']

    def test_telegram_error_in_handler(self, dp, bot):
        passed = []
        err = TelegramError('Telegram error')

        def start1(u, c):
            passed.append('start1')
            raise err

        def start2(u, c):
            passed.append('start2')

        def start3(u, c):
            passed.append('start3')

        def error(u, c):
            passed.append('error')
            passed.append(c.error)

        update = Update(
            1,
            message=Message(
                1,
                None,
                None,
                None,
                text='/start',
                entities=[
                    MessageEntity(type=MessageEntity.BOT_COMMAND, offset=0, length=len('/start'))
                ],
                bot=bot,
            ),
        )

        # If a TelegramException was caught, an error handler should be called and no further
        # handlers from the same group should be called.
        dp.add_handler(CommandHandler('start', start1), 1)
        dp.add_handler(CommandHandler('start', start2), 1)
        dp.add_handler(CommandHandler('start', start3), 2)
        dp.add_error_handler(error)
        dp.process_update(update)
        assert passed == ['start1', 'error', err, 'start3']
        assert passed[2] is err

    def test_error_while_saving_chat_data(self, bot):
        increment = []

        class OwnPersistence(BasePersistence):
            def get_callback_data(self):
                return None

            def update_callback_data(self, data):
                raise Exception

            def get_bot_data(self):
                return {}

            def update_bot_data(self, data):
                raise Exception

            def get_chat_data(self):
                return defaultdict(dict)

            def update_chat_data(self, chat_id, data):
                raise Exception

            def get_user_data(self):
                return defaultdict(dict)

            def update_user_data(self, user_id, data):
                raise Exception

            def get_conversations(self, name):
                pass

            def update_conversation(self, name, key, new_state):
                pass

            def refresh_user_data(self, user_id, user_data):
                pass

            def refresh_chat_data(self, chat_id, chat_data):
                pass

            def refresh_bot_data(self, bot_data):
                pass

            def flush(self):
                pass

        def start1(u, c):
            pass

        def error(u, c):
            increment.append("error")

        # If updating a user_data or chat_data from a persistence object throws an error,
        # the error handler should catch it

        update = Update(
            1,
            message=Message(
                1,
                None,
                Chat(1, "lala"),
                from_user=User(1, "Test", False),
                text='/start',
                entities=[
                    MessageEntity(type=MessageEntity.BOT_COMMAND, offset=0, length=len('/start'))
                ],
                bot=bot,
            ),
        )
        my_persistence = OwnPersistence()
        dp = Dispatcher(bot, None, persistence=my_persistence)
        dp.add_handler(CommandHandler('start', start1))
        dp.add_error_handler(error)
        dp.process_update(update)
        assert increment == ["error", "error", "error", "error"]

    def test_flow_stop_in_error_handler(self, dp, bot):
        passed = []
        err = TelegramError('Telegram error')

        def start1(u, c):
            passed.append('start1')
            raise err

        def start2(u, c):
            passed.append('start2')

        def start3(u, c):
            passed.append('start3')

        def error(u, c):
            passed.append('error')
            passed.append(c.error)
            raise DispatcherHandlerStop

        update = Update(
            1,
            message=Message(
                1,
                None,
                None,
                None,
                text='/start',
                entities=[
                    MessageEntity(type=MessageEntity.BOT_COMMAND, offset=0, length=len('/start'))
                ],
                bot=bot,
            ),
        )

        # If a TelegramException was caught, an error handler should be called and no further
        # handlers from the same group should be called.
        dp.add_handler(CommandHandler('start', start1), 1)
        dp.add_handler(CommandHandler('start', start2), 1)
        dp.add_handler(CommandHandler('start', start3), 2)
        dp.add_error_handler(error)
        dp.process_update(update)
        assert passed == ['start1', 'error', err]
        assert passed[2] is err

    def test_sensible_worker_thread_names(self, dp2):
        thread_names = [thread.name for thread in dp2._Dispatcher__async_threads]
        for thread_name in thread_names:
            assert thread_name.startswith(f"Bot:{dp2.bot.id}:worker:")

    def test_error_while_persisting(self, dp, caplog):
        class OwnPersistence(BasePersistence):
            def update(self, data):
                raise Exception('PersistenceError')

            def update_callback_data(self, data):
                self.update(data)

            def update_bot_data(self, data):
                self.update(data)

            def update_chat_data(self, chat_id, data):
                self.update(data)

            def update_user_data(self, user_id, data):
                self.update(data)

            def get_chat_data(self):
                pass

            def get_bot_data(self):
                pass

            def get_user_data(self):
                pass

            def get_callback_data(self):
                pass

            def get_conversations(self, name):
                pass

            def update_conversation(self, name, key, new_state):
                pass

            def refresh_bot_data(self, bot_data):
                pass

            def refresh_user_data(self, user_id, user_data):
                pass

            def refresh_chat_data(self, chat_id, chat_data):
                pass

            def flush(self):
                pass

        def callback(update, context):
            pass

        test_flag = []

        def error(update, context):
            # Only mutates the enclosing list, so no ``nonlocal`` is needed.
            test_flag.append(str(context.error) == 'PersistenceError')
            raise Exception('ErrorHandlingError')

        update = Update(
            1, message=Message(1, None, Chat(1, ''), from_user=User(1, '', False), text='Text')
        )
        handler = MessageHandler(Filters.all, callback)
        dp.add_handler(handler)
        dp.add_error_handler(error)

        dp.persistence = OwnPersistence()

        with caplog.at_level(logging.ERROR):
            dp.process_update(update)

        assert test_flag == [True, True, True, True]
        assert len(caplog.records) == 4
        for record in caplog.records:
            message = record.getMessage()
            assert message.startswith('An error was raised and an uncaught')

    def test_persisting_no_user_no_chat(self, dp):
        class OwnPersistence(BasePersistence):
            def __init__(self):
                super().__init__()
                self.test_flag_bot_data = False
                self.test_flag_chat_data = False
                self.test_flag_user_data = False

            def update_bot_data(self, data):
                self.test_flag_bot_data = True

            def update_chat_data(self, chat_id, data):
                self.test_flag_chat_data = True

            def update_user_data(self, user_id, data):
                self.test_flag_user_data = True

            def update_conversation(self, name, key, new_state):
                pass

            def get_conversations(self, name):
                pass

            def get_user_data(self):
                pass

            def get_bot_data(self):
                pass

            def get_chat_data(self):
                pass

            def refresh_bot_data(self, bot_data):
                pass

            def refresh_user_data(self, user_id, user_data):
                pass

            def refresh_chat_data(self, chat_id, chat_data):
                pass

            def get_callback_data(self):
                pass

            def update_callback_data(self, data):
                pass

            def flush(self):
                pass

        def callback(update, context):
            pass

        handler = MessageHandler(Filters.all, callback)
        dp.add_handler(handler)
        dp.persistence = OwnPersistence()

        update = Update(
            1, message=Message(1, None, None, from_user=User(1, '', False), text='Text')
        )
        dp.process_update(update)
        assert dp.persistence.test_flag_bot_data
        assert dp.persistence.test_flag_user_data
        assert not dp.persistence.test_flag_chat_data

        dp.persistence.test_flag_bot_data = False
        dp.persistence.test_flag_user_data = False
        dp.persistence.test_flag_chat_data = False
        update = Update(1, message=Message(1, None, Chat(1, ''), from_user=None, text='Text'))
        dp.process_update(update)
        assert dp.persistence.test_flag_bot_data
        assert not dp.persistence.test_flag_user_data
        assert dp.persistence.test_flag_chat_data

    def test_update_persistence_once_per_update(self, monkeypatch, dp):
        def update_persistence(*args, **kwargs):
            self.count += 1

        def dummy_callback(*args):
            pass

        monkeypatch.setattr(dp, 'update_persistence', update_persistence)

        for group in range(5):
            dp.add_handler(MessageHandler(Filters.text, dummy_callback), group=group)

        # No handler matches a text-less message, so persistence must not be updated.
        update = Update(1, message=Message(1, None, Chat(1, ''), from_user=None, text=None))
        dp.process_update(update)
        assert self.count == 0

        # Five handlers match, but persistence is updated only once.
        update = Update(1, message=Message(1, None, Chat(1, ''), from_user=None, text='text'))
        dp.process_update(update)
        assert self.count == 1

    def test_update_persistence_all_async(self, monkeypatch, dp):
        def update_persistence(*args, **kwargs):
            self.count += 1

        def dummy_callback(*args, **kwargs):
            pass

        monkeypatch.setattr(dp, 'update_persistence', update_persistence)
        monkeypatch.setattr(dp, 'run_async', dummy_callback)

        for group in range(5):
            dp.add_handler(
                MessageHandler(Filters.text, dummy_callback, run_async=True), group=group
            )

        update = Update(1, message=Message(1, None, Chat(1, ''), from_user=None, text='Text'))
        dp.process_update(update)
        assert self.count == 0

        dp.bot.defaults = Defaults(run_async=True)
        try:
            for group in range(5):
                dp.add_handler(MessageHandler(Filters.text, dummy_callback), group=group)

            update = Update(1, message=Message(1, None, Chat(1, ''), from_user=None, text='Text'))
            dp.process_update(update)
            assert self.count == 0
        finally:
            dp.bot.defaults = None

    @pytest.mark.parametrize('run_async', [DEFAULT_FALSE, False])
    def test_update_persistence_one_sync(self, monkeypatch, dp, run_async):
        def update_persistence(*args, **kwargs):
            self.count += 1

        def dummy_callback(*args, **kwargs):
            pass

        monkeypatch.setattr(dp, 'update_persistence', update_persistence)
        monkeypatch.setattr(dp, 'run_async', dummy_callback)

        for group in range(5):
            dp.add_handler(
                MessageHandler(Filters.text, dummy_callback, run_async=True), group=group
            )
        dp.add_handler(MessageHandler(Filters.text, dummy_callback, run_async=run_async), group=5)

        update = Update(1, message=Message(1, None, Chat(1, ''), from_user=None, text='Text'))
        dp.process_update(update)
        assert self.count == 1

    @pytest.mark.parametrize('run_async,expected', [(DEFAULT_FALSE, 1), (False, 1), (True, 0)])
    def test_update_persistence_defaults_async(self, monkeypatch, dp, run_async, expected):
        def update_persistence(*args, **kwargs):
            self.count += 1

        def dummy_callback(*args, **kwargs):
            pass

        monkeypatch.setattr(dp, 'update_persistence', update_persistence)
        monkeypatch.setattr(dp, 'run_async', dummy_callback)
        dp.bot.defaults = Defaults(run_async=run_async)

        try:
            for group in range(5):
                dp.add_handler(MessageHandler(Filters.text, dummy_callback), group=group)

            update = Update(1, message=Message(1, None, Chat(1, ''), from_user=None, text='Text'))
            dp.process_update(update)
            assert self.count == expected
        finally:
            dp.bot.defaults = None

    def test_custom_context_init(self, bot):
        cc = ContextTypes(
            context=CustomContext,
            user_data=int,
            chat_data=float,
            bot_data=complex,
        )

        dispatcher = Dispatcher(bot, Queue(), context_types=cc)

        assert isinstance(dispatcher.user_data[1], int)
        assert isinstance(dispatcher.chat_data[1], float)
        assert isinstance(dispatcher.bot_data, complex)

    def test_custom_context_error_handler(self, bot):
        def error_handler(_, context):
            self.received = (
                type(context),
                type(context.user_data),
                type(context.chat_data),
                type(context.bot_data),
            )

        dispatcher = Dispatcher(
            bot,
            Queue(),
            context_types=ContextTypes(
                context=CustomContext, bot_data=int, user_data=float, chat_data=complex
            ),
        )
        dispatcher.add_error_handler(error_handler)
        dispatcher.add_handler(MessageHandler(Filters.all, self.callback_raise_error))

        dispatcher.process_update(self.message_update)
        sleep(0.1)
        assert self.received == (CustomContext, float, complex, int)

    def test_custom_context_handler_callback(self, bot):
        def callback(_, context):
            self.received = (
                type(context),
                type(context.user_data),
                type(context.chat_data),
                type(context.bot_data),
            )

        dispatcher = Dispatcher(
            bot,
            Queue(),
            context_types=ContextTypes(
                context=CustomContext, bot_data=int, user_data=float, chat_data=complex
            ),
        )
        dispatcher.add_handler(MessageHandler(Filters.all, callback))

        dispatcher.process_update(self.message_update)
        sleep(0.1)
        assert self.received == (CustomContext, float, complex, int)
|