Improve Type Hints of Data Filters (#2456)

This commit is contained in:
Bibo-Joshi 2021-04-30 10:12:18 +02:00 committed by GitHub
parent 7e554584b1
commit 4645d0e32a
No known key found for this signature in database
GPG key ID: 4AEE18F83AFDEB23
2 changed files with 11 additions and 9 deletions

View file

@ -53,6 +53,8 @@ __all__ = [
from telegram.utils.deprecate import TelegramDeprecationWarning
from telegram.utils.types import SLT
DataDict = Dict[str, list]
class BaseFilter(ABC):
"""Base class for all Filters.
@ -114,7 +116,7 @@ class BaseFilter(ABC):
data_filter = False
@abstractmethod
def __call__(self, update: Update) -> Optional[Union[bool, Dict]]:
def __call__(self, update: Update) -> Optional[Union[bool, DataDict]]:
pass
def __and__(self, other: 'BaseFilter') -> 'BaseFilter':
@ -160,11 +162,11 @@ class MessageFilter(BaseFilter, ABC):
"""
def __call__(self, update: Update) -> Optional[Union[bool, Dict]]:
def __call__(self, update: Update) -> Optional[Union[bool, DataDict]]:
return self.filter(update.effective_message)
@abstractmethod
def filter(self, message: Message) -> Optional[Union[bool, Dict]]:
def filter(self, message: Message) -> Optional[Union[bool, DataDict]]:
"""This method must be overwritten.
Args:
@ -193,11 +195,11 @@ class UpdateFilter(BaseFilter, ABC):
"""
def __call__(self, update: Update) -> Optional[Union[bool, Dict]]:
def __call__(self, update: Update) -> Optional[Union[bool, DataDict]]:
return self.filter(update)
@abstractmethod
def filter(self, update: Update) -> Optional[Union[bool, Dict]]:
def filter(self, update: Update) -> Optional[Union[bool, DataDict]]:
"""This method must be overwritten.
Args:
@ -260,7 +262,7 @@ class MergedFilter(UpdateFilter):
self.data_filter = True
@staticmethod
def _merge(base_output: Union[bool, Dict], comp_output: Union[bool, Dict]) -> Dict:
def _merge(base_output: Union[bool, Dict], comp_output: Union[bool, Dict]) -> DataDict:
base = base_output if isinstance(base_output, dict) else {}
comp = comp_output if isinstance(comp_output, dict) else {}
for k in comp.keys():
@ -276,7 +278,7 @@ class MergedFilter(UpdateFilter):
base[k] = comp_value
return base
def filter(self, update: Update) -> Union[bool, Dict]: # pylint: disable=R0911
def filter(self, update: Update) -> Union[bool, DataDict]: # pylint: disable=R0911
base_output = self.base_filter(update)
# We need to check if the filters are data filters and if so return the merged data.
# If it's not a data filter or an or_filter but no matches return bool
@ -331,7 +333,7 @@ class XORFilter(UpdateFilter):
self.xor_filter = xor_filter
self.merged_filter = (base_filter & ~xor_filter) | (~base_filter & xor_filter)
def filter(self, update: Update) -> Optional[Union[bool, Dict]]:
def filter(self, update: Update) -> Optional[Union[bool, DataDict]]:
return self.merged_filter(update)
@property

View file

@ -179,7 +179,7 @@ class MessageHandler(Handler[Update]):
Filters.update.edited_message | Filters.update.edited_channel_post
)
def check_update(self, update: object) -> Optional[Union[bool, Dict[str, object]]]:
def check_update(self, update: object) -> Optional[Union[bool, Dict[str, list]]]:
"""Determines whether an update should be passed to this handlers :attr:`callback`.
Args: