Refactor Handling of Message VS Update Filters (#2032)

* Refactor handling of message vs update filters

* address review
This commit is contained in:
Bibo-Joshi 2020-07-28 09:10:32 +02:00
parent 7daddfb54d
commit 3842846b2d
6 changed files with 163 additions and 117 deletions

View file

@ -29,7 +29,7 @@ from .updater import Updater
from .callbackqueryhandler import CallbackQueryHandler from .callbackqueryhandler import CallbackQueryHandler
from .choseninlineresulthandler import ChosenInlineResultHandler from .choseninlineresulthandler import ChosenInlineResultHandler
from .inlinequeryhandler import InlineQueryHandler from .inlinequeryhandler import InlineQueryHandler
from .filters import BaseFilter, Filters from .filters import BaseFilter, MessageFilter, UpdateFilter, Filters
from .messagehandler import MessageHandler from .messagehandler import MessageHandler
from .commandhandler import CommandHandler, PrefixHandler from .commandhandler import CommandHandler, PrefixHandler
from .regexhandler import RegexHandler from .regexhandler import RegexHandler
@ -47,9 +47,9 @@ from .defaults import Defaults
__all__ = ('Dispatcher', 'JobQueue', 'Job', 'Updater', 'CallbackQueryHandler', __all__ = ('Dispatcher', 'JobQueue', 'Job', 'Updater', 'CallbackQueryHandler',
'ChosenInlineResultHandler', 'CommandHandler', 'Handler', 'InlineQueryHandler', 'ChosenInlineResultHandler', 'CommandHandler', 'Handler', 'InlineQueryHandler',
'MessageHandler', 'BaseFilter', 'Filters', 'RegexHandler', 'StringCommandHandler', 'MessageHandler', 'BaseFilter', 'MessageFilter', 'UpdateFilter', 'Filters',
'StringRegexHandler', 'TypeHandler', 'ConversationHandler', 'RegexHandler', 'StringCommandHandler', 'StringRegexHandler', 'TypeHandler',
'PreCheckoutQueryHandler', 'ShippingQueryHandler', 'MessageQueue', 'DelayQueue', 'ConversationHandler', 'PreCheckoutQueryHandler', 'ShippingQueryHandler',
'DispatcherHandlerStop', 'run_async', 'CallbackContext', 'BasePersistence', 'MessageQueue', 'DelayQueue', 'DispatcherHandlerStop', 'run_async', 'CallbackContext',
'PicklePersistence', 'DictPersistence', 'PrefixHandler', 'PollAnswerHandler', 'BasePersistence', 'PicklePersistence', 'DictPersistence', 'PrefixHandler',
'PollHandler', 'Defaults') 'PollAnswerHandler', 'PollHandler', 'Defaults')

View file

@ -25,13 +25,14 @@ from threading import Lock
from telegram import Chat, Update, MessageEntity from telegram import Chat, Update, MessageEntity
__all__ = ['Filters', 'BaseFilter', 'InvertedFilter', 'MergedFilter'] __all__ = ['Filters', 'BaseFilter', 'MessageFilter', 'UpdateFilter', 'InvertedFilter',
'MergedFilter']
class BaseFilter(ABC): class BaseFilter(ABC):
"""Base class for all Message Filters. """Base class for all Filters.
Subclassing from this class filters to be combined using bitwise operators: Filters subclassing from this class can combined using bitwise operators:
And: And:
@ -56,14 +57,17 @@ class BaseFilter(ABC):
>>> Filters.regex(r'(a?x)') | Filters.regex(r'(b?x)') >>> Filters.regex(r'(a?x)') | Filters.regex(r'(b?x)')
With a message.text of `x`, will only ever return the matches for the first filter, With ``message.text == x``, will only ever return the matches for the first filter,
since the second one is never evaluated. since the second one is never evaluated.
If you want to create your own filters create a class inheriting from this class and implement If you want to create your own filters create a class inheriting from either
a :meth:`filter` method that returns a boolean: :obj:`True` if the message should be :class:`MessageFilter` or :class:`UpdateFilter` and implement a :meth:``filter`` method that
handled, :obj:`False` otherwise. Note that the filters work only as class instances, not returns a boolean: :obj:`True` if the message should be
actual class objects (so remember to initialize your filter classes). handled, :obj:`False` otherwise.
Note that the filters work only as class instances, not
actual class objects (so remember to
initialize your filter classes).
By default the filters name (what will get printed when converted to a string for display) By default the filters name (what will get printed when converted to a string for display)
will be the class name. If you want to overwrite this assign a better name to the :attr:`name` will be the class name. If you want to overwrite this assign a better name to the :attr:`name`
@ -71,8 +75,6 @@ class BaseFilter(ABC):
Attributes: Attributes:
name (:obj:`str`): Name for this filter. Defaults to the type of filter. name (:obj:`str`): Name for this filter. Defaults to the type of filter.
update_filter (:obj:`bool`): Whether this filter should work on update. If :obj:`False` it
will run the filter on :attr:`update.effective_message`. Default is :obj:`False`.
data_filter (:obj:`bool`): Whether this filter is a data filter. A data filter should data_filter (:obj:`bool`): Whether this filter is a data filter. A data filter should
return a dict with lists. The dict will be merged with return a dict with lists. The dict will be merged with
:class:`telegram.ext.CallbackContext`'s internal dict in most cases :class:`telegram.ext.CallbackContext`'s internal dict in most cases
@ -80,14 +82,11 @@ class BaseFilter(ABC):
""" """
name = None name = None
update_filter = False
data_filter = False data_filter = False
@abstractmethod
def __call__(self, update): def __call__(self, update):
if self.update_filter: pass
return self.filter(update)
else:
return self.filter(update.effective_message)
def __and__(self, other): def __and__(self, other):
return MergedFilter(self, and_filter=other) return MergedFilter(self, and_filter=other)
@ -104,14 +103,59 @@ class BaseFilter(ABC):
self.name = self.__class__.__name__ self.name = self.__class__.__name__
return self.name return self.name
class MessageFilter(BaseFilter, ABC):
"""Base class for all Message Filters. In contrast to :class:`UpdateFilter`, the object passed
to :meth:`filter` is ``update.effective_message``.
Please see :class:`telegram.ext.BaseFilter` for details on how to create custom filters.
Attributes:
name (:obj:`str`): Name for this filter. Defaults to the type of filter.
data_filter (:obj:`bool`): Whether this filter is a data filter. A data filter should
return a dict with lists. The dict will be merged with
:class:`telegram.ext.CallbackContext`'s internal dict in most cases
(depends on the handler).
"""
def __call__(self, update):
return self.filter(update.effective_message)
@abstractmethod
def filter(self, message):
"""This method must be overwritten.
Args:
message (:class:`telegram.Message`): The message that is tested.
Returns:
:obj:`dict` or :obj:`bool`
"""
class UpdateFilter(BaseFilter, ABC):
"""Base class for all Update Filters. In contrast to :class:`UpdateFilter`, the object
passed to :meth:`filter` is ``update``, which allows to create filters like
:attr:`Filters.update.edited_message`.
Please see :class:`telegram.ext.BaseFilter` for details on how to create custom filters.
Attributes:
name (:obj:`str`): Name for this filter. Defaults to the type of filter.
data_filter (:obj:`bool`): Whether this filter is a data filter. A data filter should
return a dict with lists. The dict will be merged with
:class:`telegram.ext.CallbackContext`'s internal dict in most cases
(depends on the handler).
"""
def __call__(self, update):
return self.filter(update)
@abstractmethod @abstractmethod
def filter(self, update): def filter(self, update):
"""This method must be overwritten. """This method must be overwritten.
Note:
If :attr:`update_filter` is :obj:`False` then the first argument is `message` and of
type :class:`telegram.Message`.
Args: Args:
update (:class:`telegram.Update`): The update that is tested. update (:class:`telegram.Update`): The update that is tested.
@ -121,15 +165,13 @@ class BaseFilter(ABC):
""" """
class InvertedFilter(BaseFilter): class InvertedFilter(UpdateFilter):
"""Represents a filter that has been inverted. """Represents a filter that has been inverted.
Args: Args:
f: The filter to invert. f: The filter to invert.
""" """
update_filter = True
def __init__(self, f): def __init__(self, f):
self.f = f self.f = f
@ -140,7 +182,7 @@ class InvertedFilter(BaseFilter):
return "<inverted {}>".format(self.f) return "<inverted {}>".format(self.f)
class MergedFilter(BaseFilter): class MergedFilter(UpdateFilter):
"""Represents a filter consisting of two other filters. """Represents a filter consisting of two other filters.
Args: Args:
@ -149,8 +191,6 @@ class MergedFilter(BaseFilter):
or_filter: Optional filter to "or" with base_filter. Mutually exclusive with and_filter. or_filter: Optional filter to "or" with base_filter. Mutually exclusive with and_filter.
""" """
update_filter = True
def __init__(self, base_filter, and_filter=None, or_filter=None): def __init__(self, base_filter, and_filter=None, or_filter=None):
self.base_filter = base_filter self.base_filter = base_filter
if self.base_filter.data_filter: if self.base_filter.data_filter:
@ -215,13 +255,13 @@ class MergedFilter(BaseFilter):
self.and_filter or self.or_filter) self.and_filter or self.or_filter)
class _DiceEmoji(BaseFilter): class _DiceEmoji(MessageFilter):
def __init__(self, emoji=None, name=None): def __init__(self, emoji=None, name=None):
self.name = 'Filters.dice.{}'.format(name) if name else 'Filters.dice' self.name = 'Filters.dice.{}'.format(name) if name else 'Filters.dice'
self.emoji = emoji self.emoji = emoji
class _DiceValues(BaseFilter): class _DiceValues(MessageFilter):
def __init__(self, values, name, emoji=None): def __init__(self, values, name, emoji=None):
self.values = [values] if isinstance(values, int) else values self.values = [values] if isinstance(values, int) else values
@ -248,7 +288,8 @@ class _DiceEmoji(BaseFilter):
class Filters: class Filters:
"""Predefined filters for use as the `filter` argument of :class:`telegram.ext.MessageHandler`. """Predefined filters for use as the ``filter`` argument of
:class:`telegram.ext.MessageHandler`.
Examples: Examples:
Use ``MessageHandler(Filters.video, callback_method)`` to filter all video Use ``MessageHandler(Filters.video, callback_method)`` to filter all video
@ -256,7 +297,7 @@ class Filters:
""" """
class _All(BaseFilter): class _All(MessageFilter):
name = 'Filters.all' name = 'Filters.all'
def filter(self, message): def filter(self, message):
@ -265,10 +306,10 @@ class Filters:
all = _All() all = _All()
"""All Messages.""" """All Messages."""
class _Text(BaseFilter): class _Text(MessageFilter):
name = 'Filters.text' name = 'Filters.text'
class _TextStrings(BaseFilter): class _TextStrings(MessageFilter):
def __init__(self, strings): def __init__(self, strings):
self.strings = strings self.strings = strings
@ -316,10 +357,10 @@ class Filters:
exact matches are allowed. If not specified, will allow any text message. exact matches are allowed. If not specified, will allow any text message.
""" """
class _Caption(BaseFilter): class _Caption(MessageFilter):
name = 'Filters.caption' name = 'Filters.caption'
class _CaptionStrings(BaseFilter): class _CaptionStrings(MessageFilter):
def __init__(self, strings): def __init__(self, strings):
self.strings = strings self.strings = strings
@ -351,10 +392,10 @@ class Filters:
exact matches are allowed. If not specified, will allow any message with a caption. exact matches are allowed. If not specified, will allow any message with a caption.
""" """
class _Command(BaseFilter): class _Command(MessageFilter):
name = 'Filters.command' name = 'Filters.command'
class _CommandOnlyStart(BaseFilter): class _CommandOnlyStart(MessageFilter):
def __init__(self, only_start): def __init__(self, only_start):
self.only_start = only_start self.only_start = only_start
@ -393,7 +434,7 @@ class Filters:
command. Defaults to :obj:`True`. command. Defaults to :obj:`True`.
""" """
class regex(BaseFilter): class regex(MessageFilter):
""" """
Filters updates by searching for an occurrence of ``pattern`` in the message text. Filters updates by searching for an occurrence of ``pattern`` in the message text.
The ``re.search()`` function is used to determine whether an update should be filtered. The ``re.search()`` function is used to determine whether an update should be filtered.
@ -438,7 +479,7 @@ class Filters:
return {'matches': [match]} return {'matches': [match]}
return {} return {}
class _Reply(BaseFilter): class _Reply(MessageFilter):
name = 'Filters.reply' name = 'Filters.reply'
def filter(self, message): def filter(self, message):
@ -447,7 +488,7 @@ class Filters:
reply = _Reply() reply = _Reply()
"""Messages that are a reply to another message.""" """Messages that are a reply to another message."""
class _Audio(BaseFilter): class _Audio(MessageFilter):
name = 'Filters.audio' name = 'Filters.audio'
def filter(self, message): def filter(self, message):
@ -456,10 +497,10 @@ class Filters:
audio = _Audio() audio = _Audio()
"""Messages that contain :class:`telegram.Audio`.""" """Messages that contain :class:`telegram.Audio`."""
class _Document(BaseFilter): class _Document(MessageFilter):
name = 'Filters.document' name = 'Filters.document'
class category(BaseFilter): class category(MessageFilter):
"""Filters documents by their category in the mime-type attribute. """Filters documents by their category in the mime-type attribute.
Note: Note:
@ -492,7 +533,7 @@ class Filters:
video = category('video/') video = category('video/')
text = category('text/') text = category('text/')
class mime_type(BaseFilter): class mime_type(MessageFilter):
"""This Filter filters documents by their mime-type attribute """This Filter filters documents by their mime-type attribute
Note: Note:
@ -592,7 +633,7 @@ officedocument.wordprocessingml.document")``-
zip: Same as ``Filters.document.mime_type("application/zip")``- zip: Same as ``Filters.document.mime_type("application/zip")``-
""" """
class _Animation(BaseFilter): class _Animation(MessageFilter):
name = 'Filters.animation' name = 'Filters.animation'
def filter(self, message): def filter(self, message):
@ -601,7 +642,7 @@ officedocument.wordprocessingml.document")``-
animation = _Animation() animation = _Animation()
"""Messages that contain :class:`telegram.Animation`.""" """Messages that contain :class:`telegram.Animation`."""
class _Photo(BaseFilter): class _Photo(MessageFilter):
name = 'Filters.photo' name = 'Filters.photo'
def filter(self, message): def filter(self, message):
@ -610,7 +651,7 @@ officedocument.wordprocessingml.document")``-
photo = _Photo() photo = _Photo()
"""Messages that contain :class:`telegram.PhotoSize`.""" """Messages that contain :class:`telegram.PhotoSize`."""
class _Sticker(BaseFilter): class _Sticker(MessageFilter):
name = 'Filters.sticker' name = 'Filters.sticker'
def filter(self, message): def filter(self, message):
@ -619,7 +660,7 @@ officedocument.wordprocessingml.document")``-
sticker = _Sticker() sticker = _Sticker()
"""Messages that contain :class:`telegram.Sticker`.""" """Messages that contain :class:`telegram.Sticker`."""
class _Video(BaseFilter): class _Video(MessageFilter):
name = 'Filters.video' name = 'Filters.video'
def filter(self, message): def filter(self, message):
@ -628,7 +669,7 @@ officedocument.wordprocessingml.document")``-
video = _Video() video = _Video()
"""Messages that contain :class:`telegram.Video`.""" """Messages that contain :class:`telegram.Video`."""
class _Voice(BaseFilter): class _Voice(MessageFilter):
name = 'Filters.voice' name = 'Filters.voice'
def filter(self, message): def filter(self, message):
@ -637,7 +678,7 @@ officedocument.wordprocessingml.document")``-
voice = _Voice() voice = _Voice()
"""Messages that contain :class:`telegram.Voice`.""" """Messages that contain :class:`telegram.Voice`."""
class _VideoNote(BaseFilter): class _VideoNote(MessageFilter):
name = 'Filters.video_note' name = 'Filters.video_note'
def filter(self, message): def filter(self, message):
@ -646,7 +687,7 @@ officedocument.wordprocessingml.document")``-
video_note = _VideoNote() video_note = _VideoNote()
"""Messages that contain :class:`telegram.VideoNote`.""" """Messages that contain :class:`telegram.VideoNote`."""
class _Contact(BaseFilter): class _Contact(MessageFilter):
name = 'Filters.contact' name = 'Filters.contact'
def filter(self, message): def filter(self, message):
@ -655,7 +696,7 @@ officedocument.wordprocessingml.document")``-
contact = _Contact() contact = _Contact()
"""Messages that contain :class:`telegram.Contact`.""" """Messages that contain :class:`telegram.Contact`."""
class _Location(BaseFilter): class _Location(MessageFilter):
name = 'Filters.location' name = 'Filters.location'
def filter(self, message): def filter(self, message):
@ -664,7 +705,7 @@ officedocument.wordprocessingml.document")``-
location = _Location() location = _Location()
"""Messages that contain :class:`telegram.Location`.""" """Messages that contain :class:`telegram.Location`."""
class _Venue(BaseFilter): class _Venue(MessageFilter):
name = 'Filters.venue' name = 'Filters.venue'
def filter(self, message): def filter(self, message):
@ -673,7 +714,7 @@ officedocument.wordprocessingml.document")``-
venue = _Venue() venue = _Venue()
"""Messages that contain :class:`telegram.Venue`.""" """Messages that contain :class:`telegram.Venue`."""
class _StatusUpdate(BaseFilter): class _StatusUpdate(UpdateFilter):
"""Subset for messages containing a status update. """Subset for messages containing a status update.
Examples: Examples:
@ -681,9 +722,7 @@ officedocument.wordprocessingml.document")``-
``Filters.status_update`` for all status update messages. ``Filters.status_update`` for all status update messages.
""" """
update_filter = True class _NewChatMembers(MessageFilter):
class _NewChatMembers(BaseFilter):
name = 'Filters.status_update.new_chat_members' name = 'Filters.status_update.new_chat_members'
def filter(self, message): def filter(self, message):
@ -692,7 +731,7 @@ officedocument.wordprocessingml.document")``-
new_chat_members = _NewChatMembers() new_chat_members = _NewChatMembers()
"""Messages that contain :attr:`telegram.Message.new_chat_members`.""" """Messages that contain :attr:`telegram.Message.new_chat_members`."""
class _LeftChatMember(BaseFilter): class _LeftChatMember(MessageFilter):
name = 'Filters.status_update.left_chat_member' name = 'Filters.status_update.left_chat_member'
def filter(self, message): def filter(self, message):
@ -701,7 +740,7 @@ officedocument.wordprocessingml.document")``-
left_chat_member = _LeftChatMember() left_chat_member = _LeftChatMember()
"""Messages that contain :attr:`telegram.Message.left_chat_member`.""" """Messages that contain :attr:`telegram.Message.left_chat_member`."""
class _NewChatTitle(BaseFilter): class _NewChatTitle(MessageFilter):
name = 'Filters.status_update.new_chat_title' name = 'Filters.status_update.new_chat_title'
def filter(self, message): def filter(self, message):
@ -710,7 +749,7 @@ officedocument.wordprocessingml.document")``-
new_chat_title = _NewChatTitle() new_chat_title = _NewChatTitle()
"""Messages that contain :attr:`telegram.Message.new_chat_title`.""" """Messages that contain :attr:`telegram.Message.new_chat_title`."""
class _NewChatPhoto(BaseFilter): class _NewChatPhoto(MessageFilter):
name = 'Filters.status_update.new_chat_photo' name = 'Filters.status_update.new_chat_photo'
def filter(self, message): def filter(self, message):
@ -719,7 +758,7 @@ officedocument.wordprocessingml.document")``-
new_chat_photo = _NewChatPhoto() new_chat_photo = _NewChatPhoto()
"""Messages that contain :attr:`telegram.Message.new_chat_photo`.""" """Messages that contain :attr:`telegram.Message.new_chat_photo`."""
class _DeleteChatPhoto(BaseFilter): class _DeleteChatPhoto(MessageFilter):
name = 'Filters.status_update.delete_chat_photo' name = 'Filters.status_update.delete_chat_photo'
def filter(self, message): def filter(self, message):
@ -728,7 +767,7 @@ officedocument.wordprocessingml.document")``-
delete_chat_photo = _DeleteChatPhoto() delete_chat_photo = _DeleteChatPhoto()
"""Messages that contain :attr:`telegram.Message.delete_chat_photo`.""" """Messages that contain :attr:`telegram.Message.delete_chat_photo`."""
class _ChatCreated(BaseFilter): class _ChatCreated(MessageFilter):
name = 'Filters.status_update.chat_created' name = 'Filters.status_update.chat_created'
def filter(self, message): def filter(self, message):
@ -740,7 +779,7 @@ officedocument.wordprocessingml.document")``-
:attr: `telegram.Message.supergroup_chat_created` or :attr: `telegram.Message.supergroup_chat_created` or
:attr: `telegram.Message.channel_chat_created`.""" :attr: `telegram.Message.channel_chat_created`."""
class _Migrate(BaseFilter): class _Migrate(MessageFilter):
name = 'Filters.status_update.migrate' name = 'Filters.status_update.migrate'
def filter(self, message): def filter(self, message):
@ -750,7 +789,7 @@ officedocument.wordprocessingml.document")``-
"""Messages that contain :attr:`telegram.Message.migrate_from_chat_id` or """Messages that contain :attr:`telegram.Message.migrate_from_chat_id` or
:attr:`telegram.Message.migrate_to_chat_id`.""" :attr:`telegram.Message.migrate_to_chat_id`."""
class _PinnedMessage(BaseFilter): class _PinnedMessage(MessageFilter):
name = 'Filters.status_update.pinned_message' name = 'Filters.status_update.pinned_message'
def filter(self, message): def filter(self, message):
@ -759,7 +798,7 @@ officedocument.wordprocessingml.document")``-
pinned_message = _PinnedMessage() pinned_message = _PinnedMessage()
"""Messages that contain :attr:`telegram.Message.pinned_message`.""" """Messages that contain :attr:`telegram.Message.pinned_message`."""
class _ConnectedWebsite(BaseFilter): class _ConnectedWebsite(MessageFilter):
name = 'Filters.status_update.connected_website' name = 'Filters.status_update.connected_website'
def filter(self, message): def filter(self, message):
@ -806,7 +845,7 @@ officedocument.wordprocessingml.document")``-
:attr:`telegram.Message.pinned_message`. :attr:`telegram.Message.pinned_message`.
""" """
class _Forwarded(BaseFilter): class _Forwarded(MessageFilter):
name = 'Filters.forwarded' name = 'Filters.forwarded'
def filter(self, message): def filter(self, message):
@ -815,7 +854,7 @@ officedocument.wordprocessingml.document")``-
forwarded = _Forwarded() forwarded = _Forwarded()
"""Messages that are forwarded.""" """Messages that are forwarded."""
class _Game(BaseFilter): class _Game(MessageFilter):
name = 'Filters.game' name = 'Filters.game'
def filter(self, message): def filter(self, message):
@ -824,7 +863,7 @@ officedocument.wordprocessingml.document")``-
game = _Game() game = _Game()
"""Messages that contain :class:`telegram.Game`.""" """Messages that contain :class:`telegram.Game`."""
class entity(BaseFilter): class entity(MessageFilter):
""" """
Filters messages to only allow those which have a :class:`telegram.MessageEntity` Filters messages to only allow those which have a :class:`telegram.MessageEntity`
where their `type` matches `entity_type`. where their `type` matches `entity_type`.
@ -846,7 +885,7 @@ officedocument.wordprocessingml.document")``-
"""""" # remove method from docs """""" # remove method from docs
return any(entity.type == self.entity_type for entity in message.entities) return any(entity.type == self.entity_type for entity in message.entities)
class caption_entity(BaseFilter): class caption_entity(MessageFilter):
""" """
Filters media messages to only allow those which have a :class:`telegram.MessageEntity` Filters media messages to only allow those which have a :class:`telegram.MessageEntity`
where their `type` matches `entity_type`. where their `type` matches `entity_type`.
@ -868,7 +907,7 @@ officedocument.wordprocessingml.document")``-
"""""" # remove method from docs """""" # remove method from docs
return any(entity.type == self.entity_type for entity in message.caption_entities) return any(entity.type == self.entity_type for entity in message.caption_entities)
class _Private(BaseFilter): class _Private(MessageFilter):
name = 'Filters.private' name = 'Filters.private'
def filter(self, message): def filter(self, message):
@ -877,7 +916,7 @@ officedocument.wordprocessingml.document")``-
private = _Private() private = _Private()
"""Messages sent in a private chat.""" """Messages sent in a private chat."""
class _Group(BaseFilter): class _Group(MessageFilter):
name = 'Filters.group' name = 'Filters.group'
def filter(self, message): def filter(self, message):
@ -886,7 +925,7 @@ officedocument.wordprocessingml.document")``-
group = _Group() group = _Group()
"""Messages sent in a group chat.""" """Messages sent in a group chat."""
class user(BaseFilter): class user(MessageFilter):
"""Filters messages to allow only those which are from specified user ID(s) or """Filters messages to allow only those which are from specified user ID(s) or
username(s). username(s).
@ -1053,7 +1092,7 @@ officedocument.wordprocessingml.document")``-
return self.allow_empty return self.allow_empty
return False return False
class via_bot(BaseFilter): class via_bot(MessageFilter):
"""Filters messages to allow only those which are from specified via_bot ID(s) or """Filters messages to allow only those which are from specified via_bot ID(s) or
username(s). username(s).
@ -1220,7 +1259,7 @@ officedocument.wordprocessingml.document")``-
return self.allow_empty return self.allow_empty
return False return False
class chat(BaseFilter): class chat(MessageFilter):
"""Filters messages to allow only those which are from a specified chat ID or username. """Filters messages to allow only those which are from a specified chat ID or username.
Examples: Examples:
@ -1387,7 +1426,7 @@ officedocument.wordprocessingml.document")``-
return self.allow_empty return self.allow_empty
return False return False
class _Invoice(BaseFilter): class _Invoice(MessageFilter):
name = 'Filters.invoice' name = 'Filters.invoice'
def filter(self, message): def filter(self, message):
@ -1396,7 +1435,7 @@ officedocument.wordprocessingml.document")``-
invoice = _Invoice() invoice = _Invoice()
"""Messages that contain :class:`telegram.Invoice`.""" """Messages that contain :class:`telegram.Invoice`."""
class _SuccessfulPayment(BaseFilter): class _SuccessfulPayment(MessageFilter):
name = 'Filters.successful_payment' name = 'Filters.successful_payment'
def filter(self, message): def filter(self, message):
@ -1405,7 +1444,7 @@ officedocument.wordprocessingml.document")``-
successful_payment = _SuccessfulPayment() successful_payment = _SuccessfulPayment()
"""Messages that confirm a :class:`telegram.SuccessfulPayment`.""" """Messages that confirm a :class:`telegram.SuccessfulPayment`."""
class _PassportData(BaseFilter): class _PassportData(MessageFilter):
name = 'Filters.passport_data' name = 'Filters.passport_data'
def filter(self, message): def filter(self, message):
@ -1414,7 +1453,7 @@ officedocument.wordprocessingml.document")``-
passport_data = _PassportData() passport_data = _PassportData()
"""Messages that contain a :class:`telegram.PassportData`""" """Messages that contain a :class:`telegram.PassportData`"""
class _Poll(BaseFilter): class _Poll(MessageFilter):
name = 'Filters.poll' name = 'Filters.poll'
def filter(self, message): def filter(self, message):
@ -1457,7 +1496,7 @@ officedocument.wordprocessingml.document")``-
as for :attr:`Filters.dice`. as for :attr:`Filters.dice`.
""" """
class language(BaseFilter): class language(MessageFilter):
"""Filters messages to only allow those which are from users with a certain language code. """Filters messages to only allow those which are from users with a certain language code.
Note: Note:
@ -1486,48 +1525,42 @@ officedocument.wordprocessingml.document")``-
return message.from_user.language_code and any( return message.from_user.language_code and any(
[message.from_user.language_code.startswith(x) for x in self.lang]) [message.from_user.language_code.startswith(x) for x in self.lang])
class _UpdateType(BaseFilter): class _UpdateType(UpdateFilter):
update_filter = True
name = 'Filters.update' name = 'Filters.update'
class _Message(BaseFilter): class _Message(UpdateFilter):
name = 'Filters.update.message' name = 'Filters.update.message'
update_filter = True
def filter(self, update): def filter(self, update):
return update.message is not None return update.message is not None
message = _Message() message = _Message()
class _EditedMessage(BaseFilter): class _EditedMessage(UpdateFilter):
name = 'Filters.update.edited_message' name = 'Filters.update.edited_message'
update_filter = True
def filter(self, update): def filter(self, update):
return update.edited_message is not None return update.edited_message is not None
edited_message = _EditedMessage() edited_message = _EditedMessage()
class _Messages(BaseFilter): class _Messages(UpdateFilter):
name = 'Filters.update.messages' name = 'Filters.update.messages'
update_filter = True
def filter(self, update): def filter(self, update):
return update.message is not None or update.edited_message is not None return update.message is not None or update.edited_message is not None
messages = _Messages() messages = _Messages()
class _ChannelPost(BaseFilter): class _ChannelPost(UpdateFilter):
name = 'Filters.update.channel_post' name = 'Filters.update.channel_post'
update_filter = True
def filter(self, update): def filter(self, update):
return update.channel_post is not None return update.channel_post is not None
channel_post = _ChannelPost() channel_post = _ChannelPost()
class _EditedChannelPost(BaseFilter): class _EditedChannelPost(UpdateFilter):
update_filter = True
name = 'Filters.update.edited_channel_post' name = 'Filters.update.edited_channel_post'
def filter(self, update): def filter(self, update):
@ -1535,8 +1568,7 @@ officedocument.wordprocessingml.document")``-
edited_channel_post = _EditedChannelPost() edited_channel_post = _EditedChannelPost()
class _ChannelPosts(BaseFilter): class _ChannelPosts(UpdateFilter):
update_filter = True
name = 'Filters.update.channel_posts' name = 'Filters.update.channel_posts'
def filter(self, update): def filter(self, update):

View file

@ -25,7 +25,7 @@ class Venue(TelegramObject):
"""This object represents a venue. """This object represents a venue.
Objects of this class are comparable in terms of equality. Two objects of this class are Objects of this class are comparable in terms of equality. Two objects of this class are
considered equal, if their :attr:`location` and :attr:`title`are equal. considered equal, if their :attr:`location` and :attr:`title` are equal.
Attributes: Attributes:
location (:class:`telegram.Location`): Venue location. location (:class:`telegram.Location`): Venue location.

View file

@ -30,7 +30,7 @@ import pytz
from telegram import (Bot, Message, User, Chat, MessageEntity, Update, from telegram import (Bot, Message, User, Chat, MessageEntity, Update,
InlineQuery, CallbackQuery, ShippingQuery, PreCheckoutQuery, InlineQuery, CallbackQuery, ShippingQuery, PreCheckoutQuery,
ChosenInlineResult) ChosenInlineResult)
from telegram.ext import Dispatcher, JobQueue, Updater, BaseFilter, Defaults from telegram.ext import Dispatcher, JobQueue, Updater, MessageFilter, Defaults, UpdateFilter
from telegram.error import BadRequest from telegram.error import BadRequest
from tests.bots import get_bot from tests.bots import get_bot
@ -239,13 +239,18 @@ def make_command_update(message, edited=False, **kwargs):
return make_message_update(message, make_command_message, edited, **kwargs) return make_message_update(message, make_command_message, edited, **kwargs)
@pytest.fixture(scope='function') @pytest.fixture(scope='class',
def mock_filter(): params=[
class MockFilter(BaseFilter): {'class': MessageFilter},
{'class': UpdateFilter}
],
ids=['MessageFilter', 'UpdateFilter'])
def mock_filter(request):
class MockFilter(request.param['class']):
def __init__(self): def __init__(self):
self.tested = False self.tested = False
def filter(self, message): def filter(self, _):
self.tested = True self.tested = True
return MockFilter() return MockFilter()

View file

@ -21,7 +21,7 @@ import datetime
import pytest import pytest
from telegram import Message, User, Chat, MessageEntity, Document, Update, Dice from telegram import Message, User, Chat, MessageEntity, Document, Update, Dice
from telegram.ext import Filters, BaseFilter from telegram.ext import Filters, BaseFilter, MessageFilter, UpdateFilter
import re import re
@ -37,6 +37,16 @@ def message_entity(request):
return MessageEntity(request.param, 0, 0, url='', user='') return MessageEntity(request.param, 0, 0, url='', user='')
@pytest.fixture(scope='class',
params=[
{'class': MessageFilter},
{'class': UpdateFilter}
],
ids=['MessageFilter', 'UpdateFilter'])
def base_class(request):
return request.param['class']
class TestFilters: class TestFilters:
def test_filters_all(self, update): def test_filters_all(self, update):
assert Filters.all(update) assert Filters.all(update)
@ -962,8 +972,8 @@ class TestFilters:
with pytest.raises(TypeError, match='Can\'t instantiate abstract class _CustomFilter'): with pytest.raises(TypeError, match='Can\'t instantiate abstract class _CustomFilter'):
_CustomFilter() _CustomFilter()
def test_custom_unnamed_filter(self, update): def test_custom_unnamed_filter(self, update, base_class):
class Unnamed(BaseFilter): class Unnamed(base_class):
def filter(self, mes): def filter(self, mes):
return True return True
@ -1009,14 +1019,14 @@ class TestFilters:
assert Filters.update.channel_posts(update) assert Filters.update.channel_posts(update)
assert Filters.update(update) assert Filters.update(update)
def test_merged_short_circuit_and(self, update): def test_merged_short_circuit_and(self, update, base_class):
update.message.text = '/test' update.message.text = '/test'
update.message.entities = [MessageEntity(MessageEntity.BOT_COMMAND, 0, 5)] update.message.entities = [MessageEntity(MessageEntity.BOT_COMMAND, 0, 5)]
class TestException(Exception): class TestException(Exception):
pass pass
class RaisingFilter(BaseFilter): class RaisingFilter(base_class):
def filter(self, _): def filter(self, _):
raise TestException raise TestException
@ -1029,13 +1039,13 @@ class TestFilters:
update.message.entities = [] update.message.entities = []
(Filters.command & raising_filter)(update) (Filters.command & raising_filter)(update)
def test_merged_short_circuit_or(self, update): def test_merged_short_circuit_or(self, update, base_class):
update.message.text = 'test' update.message.text = 'test'
class TestException(Exception): class TestException(Exception):
pass pass
class RaisingFilter(BaseFilter): class RaisingFilter(base_class):
def filter(self, _): def filter(self, _):
raise TestException raise TestException
@ -1048,11 +1058,11 @@ class TestFilters:
update.message.entities = [MessageEntity(MessageEntity.BOT_COMMAND, 0, 5)] update.message.entities = [MessageEntity(MessageEntity.BOT_COMMAND, 0, 5)]
(Filters.command | raising_filter)(update) (Filters.command | raising_filter)(update)
def test_merged_data_merging_and(self, update): def test_merged_data_merging_and(self, update, base_class):
update.message.text = '/test' update.message.text = '/test'
update.message.entities = [MessageEntity(MessageEntity.BOT_COMMAND, 0, 5)] update.message.entities = [MessageEntity(MessageEntity.BOT_COMMAND, 0, 5)]
class DataFilter(BaseFilter): class DataFilter(base_class):
data_filter = True data_filter = True
def __init__(self, data): def __init__(self, data):
@ -1072,10 +1082,10 @@ class TestFilters:
result = (Filters.command & DataFilter('blah'))(update) result = (Filters.command & DataFilter('blah'))(update)
assert not result assert not result
def test_merged_data_merging_or(self, update): def test_merged_data_merging_or(self, update, base_class):
update.message.text = '/test' update.message.text = '/test'
class DataFilter(BaseFilter): class DataFilter(base_class):
data_filter = True data_filter = True
def __init__(self, data): def __init__(self, data):

View file

@ -24,7 +24,7 @@ from telegram.utils.deprecate import TelegramDeprecationWarning
from telegram import (Message, Update, Chat, Bot, User, CallbackQuery, InlineQuery, from telegram import (Message, Update, Chat, Bot, User, CallbackQuery, InlineQuery,
ChosenInlineResult, ShippingQuery, PreCheckoutQuery) ChosenInlineResult, ShippingQuery, PreCheckoutQuery)
from telegram.ext import Filters, MessageHandler, CallbackContext, JobQueue, BaseFilter from telegram.ext import Filters, MessageHandler, CallbackContext, JobQueue, UpdateFilter
message = Message(1, User(1, '', False), None, Chat(1, ''), text='Text') message = Message(1, User(1, '', False), None, Chat(1, ''), text='Text')
@ -163,8 +163,7 @@ class TestMessageHandler:
def test_callback_query_with_filter(self, message): def test_callback_query_with_filter(self, message):
class TestFilter(BaseFilter): class TestFilter(UpdateFilter):
update_filter = True
flag = False flag = False
def filter(self, u): def filter(self, u):