Merge pull request #677 from evgfilim1/new-filters

New filters for handling messages from specific chat/user id
This commit is contained in:
Noam Meltzer 2017-06-22 21:25:39 +03:00 committed by GitHub
commit 470ee86497
2 changed files with 118 additions and 0 deletions

View file

@ -359,6 +359,79 @@ class Filters(object):
group = _Group()
class user(BaseFilter):
"""Filters messages to allow only those which are from specified user ID.
Notes:
Only one of chat_id or username must be used here.
Args:
user_id(Optional[int|list]): which user ID(s) to allow through.
username(Optional[str|list]): which username(s) to allow through. If username starts
with '@' symbol, it will be ignored.
Raises:
ValueError
"""
def __init__(self, user_id=None, username=None):
if not (bool(user_id) ^ bool(username)):
raise ValueError('One and only one of user_id or username must be used')
if user_id is not None and isinstance(user_id, int):
self.user_ids = [user_id]
else:
self.user_ids = user_id
if username is None:
self.usernames = username
elif isinstance(username, str_type):
self.usernames = [username.replace('@', '')]
else:
self.usernames = [user.replace('@', '') for user in username]
def filter(self, message):
if self.user_ids is not None:
return bool(message.from_user and message.from_user.id in self.user_ids)
else:
# self.usernames is not None
return bool(message.from_user and message.from_user.username and
message.from_user.username in self.usernames)
class chat(BaseFilter):
"""Filters messages to allow only those which are from specified chat ID.
Notes:
Only one of chat_id or username must be used here.
Args:
chat_id(Optional[int|list]): which chat ID(s) to allow through.
username(Optional[str|list]): which username(s) to allow through. If username starts
with '@' symbol, it will be ignored.
Raises:
ValueError
"""
def __init__(self, chat_id=None, username=None):
if not (bool(chat_id) ^ bool(username)):
raise ValueError('One and only one of chat_id or username must be used')
if chat_id is not None and isinstance(chat_id, int):
self.chat_ids = [chat_id]
else:
self.chat_ids = chat_id
if username is None:
self.usernames = username
elif isinstance(username, str_type):
self.usernames = [username.replace('@', '')]
else:
self.usernames = [chat.replace('@', '') for chat in username]
def filter(self, message):
if self.chat_ids is not None:
return bool(message.chat_id in self.chat_ids)
else:
# self.usernames is not None
return bool(message.chat.username and message.chat.username in self.usernames)
class _Invoice(BaseFilter):
name = 'Filters.invoice'

View file

@ -213,6 +213,51 @@ class FiltersTest(BaseTest, unittest.TestCase):
self.message.chat.type = "supergroup"
self.assertTrue(Filters.group(self.message))
def test_filters_chat(self):
with self.assertRaisesRegexp(ValueError, 'chat_id or username'):
Filters.chat(chat_id=-1, username='chat')
with self.assertRaisesRegexp(ValueError, 'chat_id or username'):
Filters.chat()
def test_filters_chat_id(self):
self.assertFalse(Filters.chat(chat_id=-1)(self.message))
self.message.chat.id = -1
self.assertTrue(Filters.chat(chat_id=-1)(self.message))
self.message.chat.id = -2
self.assertTrue(Filters.chat(chat_id=[-1, -2])(self.message))
self.assertFalse(Filters.chat(chat_id=-1)(self.message))
def test_filters_chat_username(self):
self.assertFalse(Filters.chat(username='chat')(self.message))
self.message.chat.username = 'chat'
self.assertTrue(Filters.chat(username='@chat')(self.message))
self.assertTrue(Filters.chat(username='chat')(self.message))
self.assertTrue(Filters.chat(username=['chat1', 'chat', 'chat2'])(self.message))
self.assertFalse(Filters.chat(username=['@chat1', 'chat_2'])(self.message))
def test_filters_user(self):
with self.assertRaisesRegexp(ValueError, 'user_id or username'):
Filters.user(user_id=1, username='user')
with self.assertRaisesRegexp(ValueError, 'user_id or username'):
Filters.user()
def test_filters_user_id(self):
self.assertFalse(Filters.user(user_id=1)(self.message))
self.message.from_user.id = 1
self.assertTrue(Filters.user(user_id=1)(self.message))
self.message.from_user.id = 2
self.assertTrue(Filters.user(user_id=[1, 2])(self.message))
self.assertFalse(Filters.user(user_id=1)(self.message))
def test_filters_username(self):
self.assertFalse(Filters.user(username='user')(self.message))
self.assertFalse(Filters.user(username='Testuser')(self.message))
self.message.from_user.username = 'user'
self.assertTrue(Filters.user(username='@user')(self.message))
self.assertTrue(Filters.user(username='user')(self.message))
self.assertTrue(Filters.user(username=['user1', 'user', 'user2'])(self.message))
self.assertFalse(Filters.user(username=['@username', '@user_2'])(self.message))
def test_and_filters(self):
self.message.text = 'test'
self.message.forward_date = True