__call__ should return the result

Also add tests with both & and |.
This commit is contained in:
Jacob Bom 2016-09-24 18:56:54 +02:00
parent be0f5bc519
commit 61596400e1
2 changed files with 27 additions and 14 deletions

View file

@ -23,7 +23,7 @@ class BaseFilter(object):
"""Base class for all Message Filters"""
def __call__(self, message):
self.filter(message)
return self.filter(message)
def __and__(self, other):
return MergedFilter(self, and_filter=other)

View file

@ -23,14 +23,11 @@ This module contains a object that represents Tests for MessageHandler.Filters
import sys
import unittest
from datetime import datetime
import functools
from telegram import MessageEntity
sys.path.append('.')
from telegram import Message, User, Chat
from telegram import Message, User, Chat, MessageEntity
from telegram.ext import Filters
from tests.base import BaseTest
@ -40,6 +37,7 @@ class FiltersTest(BaseTest, unittest.TestCase):
def setUp(self):
self.message = Message(0, User(0, "Testuser"), datetime.now(), Chat(0, 'private'))
self.e = functools.partial(MessageEntity, offset=0, length=0)
def test_filters_text(self):
self.message.text = 'test'
@ -155,23 +153,19 @@ class FiltersTest(BaseTest, unittest.TestCase):
self.message.pinned_message = None
def test_entities_filter(self):
e = functools.partial(MessageEntity, offset=0, length=0)
self.message.entities = [e(MessageEntity.MENTION)]
self.message.entities = [self.e(MessageEntity.MENTION)]
self.assertTrue(Filters.entity(MessageEntity.MENTION)(self.message))
self.message.entities = []
self.assertFalse(Filters.entity(MessageEntity.MENTION)(self.message))
self.message.entities = [e(MessageEntity.BOLD)]
self.message.entities = [self.e(MessageEntity.BOLD)]
self.assertFalse(Filters.entity(MessageEntity.MENTION)(self.message))
self.message.entities = [e(MessageEntity.BOLD), e(MessageEntity.MENTION)]
self.message.entities = [self.e(MessageEntity.BOLD), self.e(MessageEntity.MENTION)]
self.assertTrue(Filters.entity(MessageEntity.MENTION)(self.message))
def test_and_filters(self):
# For now just test with forwarded as that's the only one that makes sense
# That'll change when we get a entities filter
self.message.text = 'test'
self.message.forward_date = True
self.assertTrue((Filters.text & Filters.forwarded)(self.message))
@ -181,9 +175,16 @@ class FiltersTest(BaseTest, unittest.TestCase):
self.message.forward_date = None
self.assertFalse((Filters.text & Filters.forwarded)(self.message))
self.message.text = 'test'
self.message.forward_date = True
self.message.entities = [self.e(MessageEntity.MENTION)]
self.assertTrue((Filters.text & Filters.forwarded & Filters.entity(MessageEntity.MENTION))(
self.message))
self.message.entities = [self.e(MessageEntity.BOLD)]
self.assertFalse((Filters.text & Filters.forwarded & Filters.entity(MessageEntity.MENTION)
)(self.message))
def test_or_filters(self):
# For now just test with forwarded as that's the only one that makes sense
# That'll change when we get a entities filter
self.message.text = 'test'
self.assertTrue((Filters.text | Filters.status_update)(self.message))
self.message.group_chat_created = True
@ -193,6 +194,18 @@ class FiltersTest(BaseTest, unittest.TestCase):
self.message.group_chat_created = False
self.assertFalse((Filters.text | Filters.status_update)(self.message))
def test_and_or_filters(self):
self.message.text = 'test'
self.message.forward_date = True
self.assertTrue((Filters.text & (Filters.forwarded | Filters.entity(MessageEntity.MENTION))
)(self.message))
self.message.forward_date = False
self.assertFalse((Filters.text & (Filters.forwarded | Filters.entity(MessageEntity.MENTION)
))(self.message))
self.message.entities = [self.e(MessageEntity.MENTION)]
self.assertTrue((Filters.text & (Filters.forwarded | Filters.entity(MessageEntity.MENTION))
)(self.message))
if __name__ == '__main__':
unittest.main()