Use @abstractmethod instead of raising NotImplementedError (#1905)

This commit is contained in:
Bibo-Joshi 2020-05-01 20:27:34 +02:00 committed by GitHub
parent 76567ba635
commit 632b989d90
No known key found for this signature in database
GPG key ID: 4AEE18F83AFDEB23
8 changed files with 92 additions and 42 deletions

View file

@ -23,13 +23,10 @@ try:
except ImportError:
import json
from abc import ABCMeta
class TelegramObject(object):
"""Base class for most telegram objects."""
__metaclass__ = ABCMeta
_id_attrs = ()
def __str__(self):

View file

@ -18,8 +18,10 @@
# along with this program. If not, see [http://www.gnu.org/licenses/].
"""This module contains the BasePersistence class."""
from abc import ABC, abstractmethod
class BasePersistence(object):
class BasePersistence(ABC):
"""Interface class for adding persistence to your bot.
Subclass this object for different implementations of a persistent bot.
@ -57,6 +59,7 @@ class BasePersistence(object):
self.store_chat_data = store_chat_data
self.store_bot_data = store_bot_data
@abstractmethod
def get_user_data(self):
""""Will be called by :class:`telegram.ext.Dispatcher` upon creation with a
persistence object. It should return the user_data if stored, or an empty
@ -65,8 +68,8 @@ class BasePersistence(object):
Returns:
:obj:`defaultdict`: The restored user data.
"""
raise NotImplementedError
@abstractmethod
def get_chat_data(self):
""""Will be called by :class:`telegram.ext.Dispatcher` upon creation with a
persistence object. It should return the chat_data if stored, or an empty
@ -75,8 +78,8 @@ class BasePersistence(object):
Returns:
:obj:`defaultdict`: The restored chat data.
"""
raise NotImplementedError
@abstractmethod
def get_bot_data(self):
""""Will be called by :class:`telegram.ext.Dispatcher` upon creation with a
persistence object. It should return the bot_data if stored, or an empty
@ -85,8 +88,8 @@ class BasePersistence(object):
Returns:
:obj:`defaultdict`: The restored bot data.
"""
raise NotImplementedError
@abstractmethod
def get_conversations(self, name):
""""Will be called by :class:`telegram.ext.Dispatcher` when a
:class:`telegram.ext.ConversationHandler` is added if
@ -99,8 +102,8 @@ class BasePersistence(object):
Returns:
:obj:`dict`: The restored conversations for the handler.
"""
raise NotImplementedError
@abstractmethod
def update_conversation(self, name, key, new_state):
"""Will be called when a :attr:`telegram.ext.ConversationHandler.update_state`
is called. this allows the storeage of the new state in the persistence.
@ -110,8 +113,8 @@ class BasePersistence(object):
key (:obj:`tuple`): The key the state is changed for.
new_state (:obj:`tuple` | :obj:`any`): The new state for the given key.
"""
raise NotImplementedError
@abstractmethod
def update_user_data(self, user_id, data):
"""Will be called by the :class:`telegram.ext.Dispatcher` after a handler has
handled an update.
@ -120,8 +123,8 @@ class BasePersistence(object):
user_id (:obj:`int`): The user the data might have been changed for.
data (:obj:`dict`): The :attr:`telegram.ext.dispatcher.user_data` [user_id].
"""
raise NotImplementedError
@abstractmethod
def update_chat_data(self, chat_id, data):
"""Will be called by the :class:`telegram.ext.Dispatcher` after a handler has
handled an update.
@ -130,8 +133,8 @@ class BasePersistence(object):
chat_id (:obj:`int`): The chat the data might have been changed for.
data (:obj:`dict`): The :attr:`telegram.ext.dispatcher.chat_data` [chat_id].
"""
raise NotImplementedError
@abstractmethod
def update_bot_data(self, data):
"""Will be called by the :class:`telegram.ext.Dispatcher` after a handler has
handled an update.
@ -139,7 +142,6 @@ class BasePersistence(object):
Args:
data (:obj:`dict`): The :attr:`telegram.ext.dispatcher.bot_data` .
"""
raise NotImplementedError
def flush(self):
"""Will be called by :class:`telegram.ext.Updater` upon receiving a stop signal. Gives the

View file

@ -21,13 +21,14 @@
import re
from future.utils import string_types
from abc import ABC, abstractmethod
from telegram import Chat, Update, MessageEntity
__all__ = ['Filters', 'BaseFilter', 'InvertedFilter', 'MergedFilter']
class BaseFilter(object):
class BaseFilter(ABC):
"""Base class for all Message Filters.
Subclassing from this class filters to be combined using bitwise operators:
@ -103,6 +104,7 @@ class BaseFilter(object):
self.name = self.__class__.__name__
return self.name
@abstractmethod
def filter(self, update):
"""This method must be overwritten.
@ -118,8 +120,6 @@ class BaseFilter(object):
"""
raise NotImplementedError
class InvertedFilter(BaseFilter):
"""Represents a filter that has been inverted.

View file

@ -18,8 +18,10 @@
# along with this program. If not, see [http://www.gnu.org/licenses/].
"""This module contains the base class for handlers as used by the Dispatcher."""
from abc import ABC, abstractmethod
class Handler(object):
class Handler(ABC):
"""The base class for all update handlers. Create custom handlers by inheriting from it.
Attributes:
@ -82,6 +84,7 @@ class Handler(object):
self.pass_user_data = pass_user_data
self.pass_chat_data = pass_chat_data
@abstractmethod
def check_update(self, update):
"""
This method is called to determine if an update should be handled by
@ -96,7 +99,6 @@ class Handler(object):
when the update gets handled.
"""
raise NotImplementedError
def handle_update(self, update, dispatcher, check_result, context=None):
"""

View file

@ -373,6 +373,12 @@ class TestDispatcher(object):
def update_user_data(self, user_id, data):
raise Exception
def get_conversations(self, name):
pass
def update_conversation(self, name, key, new_state):
pass
def start1(b, u):
pass
@ -470,6 +476,21 @@ class TestDispatcher(object):
def update_user_data(self, user_id, data):
self.update(data)
def get_chat_data(self):
pass
def get_bot_data(self):
pass
def get_user_data(self):
pass
def get_conversations(self, name):
pass
def update_conversation(self, name, key, new_state):
pass
def callback(update, context):
pass
@ -513,6 +534,21 @@ class TestDispatcher(object):
def update_user_data(self, user_id, data):
self.test_flag_user_data = True
def update_conversation(self, name, key, new_state):
pass
def get_conversations(self, name):
pass
def get_user_data(self):
pass
def get_bot_data(self):
pass
def get_chat_data(self):
pass
def callback(update, context):
pass

View file

@ -730,10 +730,8 @@ class TestFilters(object):
class _CustomFilter(BaseFilter):
pass
custom = _CustomFilter()
with pytest.raises(NotImplementedError):
(custom & Filters.text)(update)
with pytest.raises(TypeError, match='Can\'t instantiate abstract class _CustomFilter'):
_CustomFilter()
def test_custom_unnamed_filter(self, update):
class Unnamed(BaseFilter):

View file

@ -51,7 +51,33 @@ def change_directory(tmp_path):
@pytest.fixture(scope="function")
def base_persistence():
return BasePersistence(store_chat_data=True, store_user_data=True, store_bot_data=True)
class OwnPersistence(BasePersistence):
def get_bot_data(self):
raise NotImplementedError
def get_chat_data(self):
raise NotImplementedError
def get_user_data(self):
raise NotImplementedError
def get_conversations(self, name):
raise NotImplementedError
def update_bot_data(self, data):
raise NotImplementedError
def update_chat_data(self, chat_id, data):
raise NotImplementedError
def update_conversation(self, name, key, new_state):
raise NotImplementedError
def update_user_data(self, user_id, data):
raise NotImplementedError
return OwnPersistence(store_chat_data=True, store_user_data=True, store_bot_data=True)
@pytest.fixture(scope="function")
@ -100,22 +126,13 @@ class TestBasePersistence(object):
def test_creation(self, base_persistence):
assert base_persistence.store_chat_data
assert base_persistence.store_user_data
with pytest.raises(NotImplementedError):
base_persistence.get_bot_data()
with pytest.raises(NotImplementedError):
base_persistence.get_chat_data()
with pytest.raises(NotImplementedError):
base_persistence.get_user_data()
with pytest.raises(NotImplementedError):
base_persistence.get_conversations("test")
with pytest.raises(NotImplementedError):
base_persistence.update_bot_data(None)
with pytest.raises(NotImplementedError):
base_persistence.update_chat_data(None, None)
with pytest.raises(NotImplementedError):
base_persistence.update_user_data(None, None)
with pytest.raises(NotImplementedError):
base_persistence.update_conversation(None, None, None)
assert base_persistence.store_bot_data
def test_abstract_methods(self):
with pytest.raises(TypeError, match=('get_bot_data, get_chat_data, get_conversations, '
'get_user_data, update_bot_data, update_chat_data, '
'update_conversation, update_user_data')):
BasePersistence()
def test_implementation(self, updater, base_persistence):
dp = updater.dispatcher
@ -127,8 +144,6 @@ class TestBasePersistence(object):
with pytest.raises(ValueError, match="if dispatcher has no persistence"):
dp.add_handler(ConversationHandler([], {}, [], persistent=True, name="My Handler"))
dp.persistence = base_persistence
with pytest.raises(NotImplementedError):
dp.add_handler(ConversationHandler([], {}, [], persistent=True, name="My Handler"))
def test_dispatcher_integration_init(self, bot, base_persistence, chat_data, user_data,
bot_data):

View file

@ -41,7 +41,7 @@ from future.builtins import bytes
from telegram import TelegramError, Message, User, Chat, Update, Bot
from telegram.error import Unauthorized, InvalidToken, TimedOut, RetryAfter
from telegram.ext import Updater, Dispatcher, BasePersistence
from telegram.ext import Updater, Dispatcher, DictPersistence
signalskip = pytest.mark.skipif(sys.platform == 'win32',
reason='Can\'t send signals without stopping '
@ -467,7 +467,7 @@ class TestUpdater(object):
def test_mutual_exclude_persistence_dispatcher(self):
dispatcher = Dispatcher(None, None)
persistence = BasePersistence()
persistence = DictPersistence()
with pytest.raises(ValueError):
Updater(dispatcher=dispatcher, persistence=persistence)