From 3842846b2dcd0a0621e62e822841e42d4b76d4c7 Mon Sep 17 00:00:00 2001 From: Bibo-Joshi Date: Tue, 28 Jul 2020 09:10:32 +0200 Subject: [PATCH] Refactor Handling of Message VS Update Filters (#2032) * Refactor handling of message vs update filters * address review --- telegram/ext/__init__.py | 14 +-- telegram/ext/filters.py | 212 ++++++++++++++++++++--------------- telegram/files/venue.py | 2 +- tests/conftest.py | 15 ++- tests/test_filters.py | 32 ++++-- tests/test_messagehandler.py | 5 +- 6 files changed, 163 insertions(+), 117 deletions(-) diff --git a/telegram/ext/__init__.py b/telegram/ext/__init__.py index e77b55673..a39b067e9 100644 --- a/telegram/ext/__init__.py +++ b/telegram/ext/__init__.py @@ -29,7 +29,7 @@ from .updater import Updater from .callbackqueryhandler import CallbackQueryHandler from .choseninlineresulthandler import ChosenInlineResultHandler from .inlinequeryhandler import InlineQueryHandler -from .filters import BaseFilter, Filters +from .filters import BaseFilter, MessageFilter, UpdateFilter, Filters from .messagehandler import MessageHandler from .commandhandler import CommandHandler, PrefixHandler from .regexhandler import RegexHandler @@ -47,9 +47,9 @@ from .defaults import Defaults __all__ = ('Dispatcher', 'JobQueue', 'Job', 'Updater', 'CallbackQueryHandler', 'ChosenInlineResultHandler', 'CommandHandler', 'Handler', 'InlineQueryHandler', - 'MessageHandler', 'BaseFilter', 'Filters', 'RegexHandler', 'StringCommandHandler', - 'StringRegexHandler', 'TypeHandler', 'ConversationHandler', - 'PreCheckoutQueryHandler', 'ShippingQueryHandler', 'MessageQueue', 'DelayQueue', - 'DispatcherHandlerStop', 'run_async', 'CallbackContext', 'BasePersistence', - 'PicklePersistence', 'DictPersistence', 'PrefixHandler', 'PollAnswerHandler', - 'PollHandler', 'Defaults') + 'MessageHandler', 'BaseFilter', 'MessageFilter', 'UpdateFilter', 'Filters', + 'RegexHandler', 'StringCommandHandler', 'StringRegexHandler', 'TypeHandler', + 'ConversationHandler', 'PreCheckoutQueryHandler', 'ShippingQueryHandler', + 'MessageQueue', 'DelayQueue', 'DispatcherHandlerStop', 'run_async', 'CallbackContext', + 'BasePersistence', 'PicklePersistence', 'DictPersistence', 'PrefixHandler', + 'PollAnswerHandler', 'PollHandler', 'Defaults') diff --git a/telegram/ext/filters.py b/telegram/ext/filters.py index d4e9cafe3..964448904 100644 --- a/telegram/ext/filters.py +++ b/telegram/ext/filters.py @@ -25,13 +25,14 @@ from threading import Lock from telegram import Chat, Update, MessageEntity -__all__ = ['Filters', 'BaseFilter', 'InvertedFilter', 'MergedFilter'] +__all__ = ['Filters', 'BaseFilter', 'MessageFilter', 'UpdateFilter', 'InvertedFilter', + 'MergedFilter'] class BaseFilter(ABC): - """Base class for all Message Filters. + """Base class for all Filters. - Subclassing from this class filters to be combined using bitwise operators: + Filters subclassing from this class can combined using bitwise operators: And: @@ -56,14 +57,17 @@ class BaseFilter(ABC): >>> Filters.regex(r'(a?x)') | Filters.regex(r'(b?x)') - With a message.text of `x`, will only ever return the matches for the first filter, + With ``message.text == x``, will only ever return the matches for the first filter, since the second one is never evaluated. - If you want to create your own filters create a class inheriting from this class and implement - a :meth:`filter` method that returns a boolean: :obj:`True` if the message should be - handled, :obj:`False` otherwise. Note that the filters work only as class instances, not - actual class objects (so remember to initialize your filter classes). + If you want to create your own filters create a class inheriting from either + :class:`MessageFilter` or :class:`UpdateFilter` and implement a :meth:``filter`` method that + returns a boolean: :obj:`True` if the message should be + handled, :obj:`False` otherwise. + Note that the filters work only as class instances, not + actual class objects (so remember to + initialize your filter classes). By default the filters name (what will get printed when converted to a string for display) will be the class name. If you want to overwrite this assign a better name to the :attr:`name` @@ -71,8 +75,6 @@ class BaseFilter(ABC): Attributes: name (:obj:`str`): Name for this filter. Defaults to the type of filter. - update_filter (:obj:`bool`): Whether this filter should work on update. If :obj:`False` it - will run the filter on :attr:`update.effective_message`. Default is :obj:`False`. data_filter (:obj:`bool`): Whether this filter is a data filter. A data filter should return a dict with lists. The dict will be merged with :class:`telegram.ext.CallbackContext`'s internal dict in most cases @@ -80,14 +82,11 @@ class BaseFilter(ABC): """ name = None - update_filter = False data_filter = False + @abstractmethod def __call__(self, update): - if self.update_filter: - return self.filter(update) - else: - return self.filter(update.effective_message) + pass def __and__(self, other): return MergedFilter(self, and_filter=other) @@ -104,14 +103,59 @@ class BaseFilter(ABC): self.name = self.__class__.__name__ return self.name + +class MessageFilter(BaseFilter, ABC): + """Base class for all Message Filters. In contrast to :class:`UpdateFilter`, the object passed + to :meth:`filter` is ``update.effective_message``. + + Please see :class:`telegram.ext.BaseFilter` for details on how to create custom filters. + + Attributes: + name (:obj:`str`): Name for this filter. Defaults to the type of filter. + data_filter (:obj:`bool`): Whether this filter is a data filter. A data filter should + return a dict with lists. The dict will be merged with + :class:`telegram.ext.CallbackContext`'s internal dict in most cases + (depends on the handler). + + """ + def __call__(self, update): + return self.filter(update.effective_message) + + @abstractmethod + def filter(self, message): + """This method must be overwritten. + + Args: + message (:class:`telegram.Message`): The message that is tested. + + Returns: + :obj:`dict` or :obj:`bool` + + """ + + +class UpdateFilter(BaseFilter, ABC): + """Base class for all Update Filters. In contrast to :class:`UpdateFilter`, the object + passed to :meth:`filter` is ``update``, which allows to create filters like + :attr:`Filters.update.edited_message`. + + Please see :class:`telegram.ext.BaseFilter` for details on how to create custom filters. + + Attributes: + name (:obj:`str`): Name for this filter. Defaults to the type of filter. + data_filter (:obj:`bool`): Whether this filter is a data filter. A data filter should + return a dict with lists. The dict will be merged with + :class:`telegram.ext.CallbackContext`'s internal dict in most cases + (depends on the handler). + + """ + def __call__(self, update): + return self.filter(update) + @abstractmethod def filter(self, update): """This method must be overwritten. - Note: - If :attr:`update_filter` is :obj:`False` then the first argument is `message` and of - type :class:`telegram.Message`. - Args: update (:class:`telegram.Update`): The update that is tested. @@ -121,15 +165,13 @@ class BaseFilter(ABC): """ -class InvertedFilter(BaseFilter): +class InvertedFilter(UpdateFilter): """Represents a filter that has been inverted. Args: f: The filter to invert. """ - update_filter = True - def __init__(self, f): self.f = f @@ -140,7 +182,7 @@ class InvertedFilter(BaseFilter): return "".format(self.f) -class MergedFilter(BaseFilter): +class MergedFilter(UpdateFilter): """Represents a filter consisting of two other filters. Args: @@ -149,8 +191,6 @@ class MergedFilter(BaseFilter): or_filter: Optional filter to "or" with base_filter. Mutually exclusive with and_filter. """ - update_filter = True - def __init__(self, base_filter, and_filter=None, or_filter=None): self.base_filter = base_filter if self.base_filter.data_filter: @@ -215,13 +255,13 @@ class MergedFilter(BaseFilter): self.and_filter or self.or_filter) -class _DiceEmoji(BaseFilter): +class _DiceEmoji(MessageFilter): def __init__(self, emoji=None, name=None): self.name = 'Filters.dice.{}'.format(name) if name else 'Filters.dice' self.emoji = emoji - class _DiceValues(BaseFilter): + class _DiceValues(MessageFilter): def __init__(self, values, name, emoji=None): self.values = [values] if isinstance(values, int) else values @@ -248,7 +288,8 @@ class _DiceEmoji(BaseFilter): class Filters: - """Predefined filters for use as the `filter` argument of :class:`telegram.ext.MessageHandler`. + """Predefined filters for use as the ``filter`` argument of + :class:`telegram.ext.MessageHandler`. Examples: Use ``MessageHandler(Filters.video, callback_method)`` to filter all video @@ -256,7 +297,7 @@ class Filters: """ - class _All(BaseFilter): + class _All(MessageFilter): name = 'Filters.all' def filter(self, message): @@ -265,10 +306,10 @@ class Filters: all = _All() """All Messages.""" - class _Text(BaseFilter): + class _Text(MessageFilter): name = 'Filters.text' - class _TextStrings(BaseFilter): + class _TextStrings(MessageFilter): def __init__(self, strings): self.strings = strings @@ -316,10 +357,10 @@ class Filters: exact matches are allowed. If not specified, will allow any text message. """ - class _Caption(BaseFilter): + class _Caption(MessageFilter): name = 'Filters.caption' - class _CaptionStrings(BaseFilter): + class _CaptionStrings(MessageFilter): def __init__(self, strings): self.strings = strings @@ -351,10 +392,10 @@ class Filters: exact matches are allowed. If not specified, will allow any message with a caption. """ - class _Command(BaseFilter): + class _Command(MessageFilter): name = 'Filters.command' - class _CommandOnlyStart(BaseFilter): + class _CommandOnlyStart(MessageFilter): def __init__(self, only_start): self.only_start = only_start @@ -393,7 +434,7 @@ class Filters: command. Defaults to :obj:`True`. """ - class regex(BaseFilter): + class regex(MessageFilter): """ Filters updates by searching for an occurrence of ``pattern`` in the message text. The ``re.search()`` function is used to determine whether an update should be filtered. @@ -438,7 +479,7 @@ class Filters: return {'matches': [match]} return {} - class _Reply(BaseFilter): + class _Reply(MessageFilter): name = 'Filters.reply' def filter(self, message): @@ -447,7 +488,7 @@ class Filters: reply = _Reply() """Messages that are a reply to another message.""" - class _Audio(BaseFilter): + class _Audio(MessageFilter): name = 'Filters.audio' def filter(self, message): @@ -456,10 +497,10 @@ class Filters: audio = _Audio() """Messages that contain :class:`telegram.Audio`.""" - class _Document(BaseFilter): + class _Document(MessageFilter): name = 'Filters.document' - class category(BaseFilter): + class category(MessageFilter): """Filters documents by their category in the mime-type attribute. Note: @@ -492,7 +533,7 @@ class Filters: video = category('video/') text = category('text/') - class mime_type(BaseFilter): + class mime_type(MessageFilter): """This Filter filters documents by their mime-type attribute Note: @@ -592,7 +633,7 @@ officedocument.wordprocessingml.document")``- zip: Same as ``Filters.document.mime_type("application/zip")``- """ - class _Animation(BaseFilter): + class _Animation(MessageFilter): name = 'Filters.animation' def filter(self, message): @@ -601,7 +642,7 @@ officedocument.wordprocessingml.document")``- animation = _Animation() """Messages that contain :class:`telegram.Animation`.""" - class _Photo(BaseFilter): + class _Photo(MessageFilter): name = 'Filters.photo' def filter(self, message): @@ -610,7 +651,7 @@ officedocument.wordprocessingml.document")``- photo = _Photo() """Messages that contain :class:`telegram.PhotoSize`.""" - class _Sticker(BaseFilter): + class _Sticker(MessageFilter): name = 'Filters.sticker' def filter(self, message): @@ -619,7 +660,7 @@ officedocument.wordprocessingml.document")``- sticker = _Sticker() """Messages that contain :class:`telegram.Sticker`.""" - class _Video(BaseFilter): + class _Video(MessageFilter): name = 'Filters.video' def filter(self, message): @@ -628,7 +669,7 @@ officedocument.wordprocessingml.document")``- video = _Video() """Messages that contain :class:`telegram.Video`.""" - class _Voice(BaseFilter): + class _Voice(MessageFilter): name = 'Filters.voice' def filter(self, message): @@ -637,7 +678,7 @@ officedocument.wordprocessingml.document")``- voice = _Voice() """Messages that contain :class:`telegram.Voice`.""" - class _VideoNote(BaseFilter): + class _VideoNote(MessageFilter): name = 'Filters.video_note' def filter(self, message): @@ -646,7 +687,7 @@ officedocument.wordprocessingml.document")``- video_note = _VideoNote() """Messages that contain :class:`telegram.VideoNote`.""" - class _Contact(BaseFilter): + class _Contact(MessageFilter): name = 'Filters.contact' def filter(self, message): @@ -655,7 +696,7 @@ officedocument.wordprocessingml.document")``- contact = _Contact() """Messages that contain :class:`telegram.Contact`.""" - class _Location(BaseFilter): + class _Location(MessageFilter): name = 'Filters.location' def filter(self, message): @@ -664,7 +705,7 @@ officedocument.wordprocessingml.document")``- location = _Location() """Messages that contain :class:`telegram.Location`.""" - class _Venue(BaseFilter): + class _Venue(MessageFilter): name = 'Filters.venue' def filter(self, message): @@ -673,7 +714,7 @@ officedocument.wordprocessingml.document")``- venue = _Venue() """Messages that contain :class:`telegram.Venue`.""" - class _StatusUpdate(BaseFilter): + class _StatusUpdate(UpdateFilter): """Subset for messages containing a status update. Examples: @@ -681,9 +722,7 @@ officedocument.wordprocessingml.document")``- ``Filters.status_update`` for all status update messages. """ - update_filter = True - - class _NewChatMembers(BaseFilter): + class _NewChatMembers(MessageFilter): name = 'Filters.status_update.new_chat_members' def filter(self, message): @@ -692,7 +731,7 @@ officedocument.wordprocessingml.document")``- new_chat_members = _NewChatMembers() """Messages that contain :attr:`telegram.Message.new_chat_members`.""" - class _LeftChatMember(BaseFilter): + class _LeftChatMember(MessageFilter): name = 'Filters.status_update.left_chat_member' def filter(self, message): @@ -701,7 +740,7 @@ officedocument.wordprocessingml.document")``- left_chat_member = _LeftChatMember() """Messages that contain :attr:`telegram.Message.left_chat_member`.""" - class _NewChatTitle(BaseFilter): + class _NewChatTitle(MessageFilter): name = 'Filters.status_update.new_chat_title' def filter(self, message): @@ -710,7 +749,7 @@ officedocument.wordprocessingml.document")``- new_chat_title = _NewChatTitle() """Messages that contain :attr:`telegram.Message.new_chat_title`.""" - class _NewChatPhoto(BaseFilter): + class _NewChatPhoto(MessageFilter): name = 'Filters.status_update.new_chat_photo' def filter(self, message): @@ -719,7 +758,7 @@ officedocument.wordprocessingml.document")``- new_chat_photo = _NewChatPhoto() """Messages that contain :attr:`telegram.Message.new_chat_photo`.""" - class _DeleteChatPhoto(BaseFilter): + class _DeleteChatPhoto(MessageFilter): name = 'Filters.status_update.delete_chat_photo' def filter(self, message): @@ -728,7 +767,7 @@ officedocument.wordprocessingml.document")``- delete_chat_photo = _DeleteChatPhoto() """Messages that contain :attr:`telegram.Message.delete_chat_photo`.""" - class _ChatCreated(BaseFilter): + class _ChatCreated(MessageFilter): name = 'Filters.status_update.chat_created' def filter(self, message): @@ -740,7 +779,7 @@ officedocument.wordprocessingml.document")``- :attr: `telegram.Message.supergroup_chat_created` or :attr: `telegram.Message.channel_chat_created`.""" - class _Migrate(BaseFilter): + class _Migrate(MessageFilter): name = 'Filters.status_update.migrate' def filter(self, message): @@ -750,7 +789,7 @@ officedocument.wordprocessingml.document")``- """Messages that contain :attr:`telegram.Message.migrate_from_chat_id` or :attr:`telegram.Message.migrate_to_chat_id`.""" - class _PinnedMessage(BaseFilter): + class _PinnedMessage(MessageFilter): name = 'Filters.status_update.pinned_message' def filter(self, message): @@ -759,7 +798,7 @@ officedocument.wordprocessingml.document")``- pinned_message = _PinnedMessage() """Messages that contain :attr:`telegram.Message.pinned_message`.""" - class _ConnectedWebsite(BaseFilter): + class _ConnectedWebsite(MessageFilter): name = 'Filters.status_update.connected_website' def filter(self, message): @@ -806,7 +845,7 @@ officedocument.wordprocessingml.document")``- :attr:`telegram.Message.pinned_message`. """ - class _Forwarded(BaseFilter): + class _Forwarded(MessageFilter): name = 'Filters.forwarded' def filter(self, message): @@ -815,7 +854,7 @@ officedocument.wordprocessingml.document")``- forwarded = _Forwarded() """Messages that are forwarded.""" - class _Game(BaseFilter): + class _Game(MessageFilter): name = 'Filters.game' def filter(self, message): @@ -824,7 +863,7 @@ officedocument.wordprocessingml.document")``- game = _Game() """Messages that contain :class:`telegram.Game`.""" - class entity(BaseFilter): + class entity(MessageFilter): """ Filters messages to only allow those which have a :class:`telegram.MessageEntity` where their `type` matches `entity_type`. @@ -846,7 +885,7 @@ officedocument.wordprocessingml.document")``- """""" # remove method from docs return any(entity.type == self.entity_type for entity in message.entities) - class caption_entity(BaseFilter): + class caption_entity(MessageFilter): """ Filters media messages to only allow those which have a :class:`telegram.MessageEntity` where their `type` matches `entity_type`. @@ -868,7 +907,7 @@ officedocument.wordprocessingml.document")``- """""" # remove method from docs return any(entity.type == self.entity_type for entity in message.caption_entities) - class _Private(BaseFilter): + class _Private(MessageFilter): name = 'Filters.private' def filter(self, message): @@ -877,7 +916,7 @@ officedocument.wordprocessingml.document")``- private = _Private() """Messages sent in a private chat.""" - class _Group(BaseFilter): + class _Group(MessageFilter): name = 'Filters.group' def filter(self, message): @@ -886,7 +925,7 @@ officedocument.wordprocessingml.document")``- group = _Group() """Messages sent in a group chat.""" - class user(BaseFilter): + class user(MessageFilter): """Filters messages to allow only those which are from specified user ID(s) or username(s). @@ -1053,7 +1092,7 @@ officedocument.wordprocessingml.document")``- return self.allow_empty return False - class via_bot(BaseFilter): + class via_bot(MessageFilter): """Filters messages to allow only those which are from specified via_bot ID(s) or username(s). @@ -1220,7 +1259,7 @@ officedocument.wordprocessingml.document")``- return self.allow_empty return False - class chat(BaseFilter): + class chat(MessageFilter): """Filters messages to allow only those which are from a specified chat ID or username. Examples: @@ -1387,7 +1426,7 @@ officedocument.wordprocessingml.document")``- return self.allow_empty return False - class _Invoice(BaseFilter): + class _Invoice(MessageFilter): name = 'Filters.invoice' def filter(self, message): @@ -1396,7 +1435,7 @@ officedocument.wordprocessingml.document")``- invoice = _Invoice() """Messages that contain :class:`telegram.Invoice`.""" - class _SuccessfulPayment(BaseFilter): + class _SuccessfulPayment(MessageFilter): name = 'Filters.successful_payment' def filter(self, message): @@ -1405,7 +1444,7 @@ officedocument.wordprocessingml.document")``- successful_payment = _SuccessfulPayment() """Messages that confirm a :class:`telegram.SuccessfulPayment`.""" - class _PassportData(BaseFilter): + class _PassportData(MessageFilter): name = 'Filters.passport_data' def filter(self, message): @@ -1414,7 +1453,7 @@ officedocument.wordprocessingml.document")``- passport_data = _PassportData() """Messages that contain a :class:`telegram.PassportData`""" - class _Poll(BaseFilter): + class _Poll(MessageFilter): name = 'Filters.poll' def filter(self, message): @@ -1457,7 +1496,7 @@ officedocument.wordprocessingml.document")``- as for :attr:`Filters.dice`. """ - class language(BaseFilter): + class language(MessageFilter): """Filters messages to only allow those which are from users with a certain language code. Note: @@ -1486,48 +1525,42 @@ officedocument.wordprocessingml.document")``- return message.from_user.language_code and any( [message.from_user.language_code.startswith(x) for x in self.lang]) - class _UpdateType(BaseFilter): - update_filter = True + class _UpdateType(UpdateFilter): name = 'Filters.update' - class _Message(BaseFilter): + class _Message(UpdateFilter): name = 'Filters.update.message' - update_filter = True def filter(self, update): return update.message is not None message = _Message() - class _EditedMessage(BaseFilter): + class _EditedMessage(UpdateFilter): name = 'Filters.update.edited_message' - update_filter = True def filter(self, update): return update.edited_message is not None edited_message = _EditedMessage() - class _Messages(BaseFilter): + class _Messages(UpdateFilter): name = 'Filters.update.messages' - update_filter = True def filter(self, update): return update.message is not None or update.edited_message is not None messages = _Messages() - class _ChannelPost(BaseFilter): + class _ChannelPost(UpdateFilter): name = 'Filters.update.channel_post' - update_filter = True def filter(self, update): return update.channel_post is not None channel_post = _ChannelPost() - class _EditedChannelPost(BaseFilter): - update_filter = True + class _EditedChannelPost(UpdateFilter): name = 'Filters.update.edited_channel_post' def filter(self, update): @@ -1535,8 +1568,7 @@ officedocument.wordprocessingml.document")``- edited_channel_post = _EditedChannelPost() - class _ChannelPosts(BaseFilter): - update_filter = True + class _ChannelPosts(UpdateFilter): name = 'Filters.update.channel_posts' def filter(self, update): diff --git a/telegram/files/venue.py b/telegram/files/venue.py index a54d79785..142a0e9bf 100644 --- a/telegram/files/venue.py +++ b/telegram/files/venue.py @@ -25,7 +25,7 @@ class Venue(TelegramObject): """This object represents a venue. Objects of this class are comparable in terms of equality. Two objects of this class are - considered equal, if their :attr:`location` and :attr:`title`are equal. + considered equal, if their :attr:`location` and :attr:`title` are equal. Attributes: location (:class:`telegram.Location`): Venue location. diff --git a/tests/conftest.py b/tests/conftest.py index 9a6d8fbe6..ee9e70697 100644 --- a/tests/conftest.py +++ b/tests/conftest.py @@ -30,7 +30,7 @@ import pytz from telegram import (Bot, Message, User, Chat, MessageEntity, Update, InlineQuery, CallbackQuery, ShippingQuery, PreCheckoutQuery, ChosenInlineResult) -from telegram.ext import Dispatcher, JobQueue, Updater, BaseFilter, Defaults +from telegram.ext import Dispatcher, JobQueue, Updater, MessageFilter, Defaults, UpdateFilter from telegram.error import BadRequest from tests.bots import get_bot @@ -239,13 +239,18 @@ def make_command_update(message, edited=False, **kwargs): return make_message_update(message, make_command_message, edited, **kwargs) -@pytest.fixture(scope='function') -def mock_filter(): - class MockFilter(BaseFilter): +@pytest.fixture(scope='class', + params=[ + {'class': MessageFilter}, + {'class': UpdateFilter} + ], + ids=['MessageFilter', 'UpdateFilter']) +def mock_filter(request): + class MockFilter(request.param['class']): def __init__(self): self.tested = False - def filter(self, message): + def filter(self, _): self.tested = True return MockFilter() diff --git a/tests/test_filters.py b/tests/test_filters.py index 03847413d..fad30709d 100644 --- a/tests/test_filters.py +++ b/tests/test_filters.py @@ -21,7 +21,7 @@ import datetime import pytest from telegram import Message, User, Chat, MessageEntity, Document, Update, Dice -from telegram.ext import Filters, BaseFilter +from telegram.ext import Filters, BaseFilter, MessageFilter, UpdateFilter import re @@ -37,6 +37,16 @@ def message_entity(request): return MessageEntity(request.param, 0, 0, url='', user='') +@pytest.fixture(scope='class', + params=[ + {'class': MessageFilter}, + {'class': UpdateFilter} + ], + ids=['MessageFilter', 'UpdateFilter']) +def base_class(request): + return request.param['class'] + + class TestFilters: def test_filters_all(self, update): assert Filters.all(update) @@ -962,8 +972,8 @@ class TestFilters: with pytest.raises(TypeError, match='Can\'t instantiate abstract class _CustomFilter'): _CustomFilter() - def test_custom_unnamed_filter(self, update): - class Unnamed(BaseFilter): + def test_custom_unnamed_filter(self, update, base_class): + class Unnamed(base_class): def filter(self, mes): return True @@ -1009,14 +1019,14 @@ class TestFilters: assert Filters.update.channel_posts(update) assert Filters.update(update) - def test_merged_short_circuit_and(self, update): + def test_merged_short_circuit_and(self, update, base_class): update.message.text = '/test' update.message.entities = [MessageEntity(MessageEntity.BOT_COMMAND, 0, 5)] class TestException(Exception): pass - class RaisingFilter(BaseFilter): + class RaisingFilter(base_class): def filter(self, _): raise TestException @@ -1029,13 +1039,13 @@ class TestFilters: update.message.entities = [] (Filters.command & raising_filter)(update) - def test_merged_short_circuit_or(self, update): + def test_merged_short_circuit_or(self, update, base_class): update.message.text = 'test' class TestException(Exception): pass - class RaisingFilter(BaseFilter): + class RaisingFilter(base_class): def filter(self, _): raise TestException @@ -1048,11 +1058,11 @@ class TestFilters: update.message.entities = [MessageEntity(MessageEntity.BOT_COMMAND, 0, 5)] (Filters.command | raising_filter)(update) - def test_merged_data_merging_and(self, update): + def test_merged_data_merging_and(self, update, base_class): update.message.text = '/test' update.message.entities = [MessageEntity(MessageEntity.BOT_COMMAND, 0, 5)] - class DataFilter(BaseFilter): + class DataFilter(base_class): data_filter = True def __init__(self, data): @@ -1072,10 +1082,10 @@ class TestFilters: result = (Filters.command & DataFilter('blah'))(update) assert not result - def test_merged_data_merging_or(self, update): + def test_merged_data_merging_or(self, update, base_class): update.message.text = '/test' - class DataFilter(BaseFilter): + class DataFilter(base_class): data_filter = True def __init__(self, data): diff --git a/tests/test_messagehandler.py b/tests/test_messagehandler.py index 12f78c23e..359289995 100644 --- a/tests/test_messagehandler.py +++ b/tests/test_messagehandler.py @@ -24,7 +24,7 @@ from telegram.utils.deprecate import TelegramDeprecationWarning from telegram import (Message, Update, Chat, Bot, User, CallbackQuery, InlineQuery, ChosenInlineResult, ShippingQuery, PreCheckoutQuery) -from telegram.ext import Filters, MessageHandler, CallbackContext, JobQueue, BaseFilter +from telegram.ext import Filters, MessageHandler, CallbackContext, JobQueue, UpdateFilter message = Message(1, User(1, '', False), None, Chat(1, ''), text='Text') @@ -163,8 +163,7 @@ class TestMessageHandler: def test_callback_query_with_filter(self, message): - class TestFilter(BaseFilter): - update_filter = True + class TestFilter(UpdateFilter): flag = False def filter(self, u):