diff --git a/telegram/error.py b/telegram/error.py index 32b9da7ae..dc6b26be7 100644 --- a/telegram/error.py +++ b/telegram/error.py @@ -51,6 +51,9 @@ class TelegramError(Exception): def __str__(self): return '%s' % (self.message) + def __reduce__(self): + return self.__class__, (self.message,) + class Unauthorized(TelegramError): pass @@ -60,6 +63,9 @@ class InvalidToken(TelegramError): def __init__(self): super().__init__('Invalid token') + def __reduce__(self): + return self.__class__, () + class NetworkError(TelegramError): pass @@ -73,6 +79,9 @@ class TimedOut(NetworkError): def __init__(self): super().__init__('Timed out') + def __reduce__(self): + return self.__class__, () + class ChatMigrated(TelegramError): """ @@ -85,6 +94,9 @@ class ChatMigrated(TelegramError): super().__init__('Group migrated to supergroup. New chat id: {}'.format(new_chat_id)) self.new_chat_id = new_chat_id + def __reduce__(self): + return self.__class__, (self.new_chat_id,) + class RetryAfter(TelegramError): """ @@ -94,9 +106,12 @@ class RetryAfter(TelegramError): """ def __init__(self, retry_after): - super().__init__('Flood control exceeded. Retry in {} seconds'.format(retry_after)) + super().__init__('Flood control exceeded. Retry in {} seconds'.format(float(retry_after))) self.retry_after = float(retry_after) + def __reduce__(self): + return self.__class__, (self.retry_after,) + class Conflict(TelegramError): """ @@ -109,3 +124,6 @@ class Conflict(TelegramError): def __init__(self, msg): super().__init__(msg) + + def __reduce__(self): + return self.__class__, (self.message,) diff --git a/telegram/passport/credentials.py b/telegram/passport/credentials.py index 6981ccecc..52dc2f304 100644 --- a/telegram/passport/credentials.py +++ b/telegram/passport/credentials.py @@ -39,6 +39,10 @@ class TelegramDecryptionError(TelegramError): def __init__(self, message): super().__init__("TelegramDecryptionError: {}".format(message)) + self._msg = message + + def __reduce__(self): + return self.__class__, (self._msg,) def decrypt(secret, hash, data): diff --git a/tests/test_error.py b/tests/test_error.py index a20880248..65ab8dbc0 100644 --- a/tests/test_error.py +++ b/tests/test_error.py @@ -16,9 +16,12 @@ # # You should have received a copy of the GNU Lesser Public License # along with this program. If not, see [http://www.gnu.org/licenses/]. +import pickle +from collections import defaultdict + import pytest -from telegram import TelegramError +from telegram import TelegramError, TelegramDecryptionError from telegram.error import Unauthorized, InvalidToken, NetworkError, BadRequest, TimedOut, \ ChatMigrated, RetryAfter, Conflict @@ -81,9 +84,53 @@ class TestErrors: assert e.new_chat_id == 1234 def test_retry_after(self): - with pytest.raises(RetryAfter, match="Flood control exceeded. Retry in 12 seconds"): + with pytest.raises(RetryAfter, match="Flood control exceeded. Retry in 12.0 seconds"): raise RetryAfter(12) def test_conflict(self): with pytest.raises(Conflict, match='Something something.'): raise Conflict('Something something.') + + @pytest.mark.parametrize( + "exception, attributes", + [ + (TelegramError("test message"), ["message"]), + (Unauthorized("test message"), ["message"]), + (InvalidToken(), ["message"]), + (NetworkError("test message"), ["message"]), + (BadRequest("test message"), ["message"]), + (TimedOut(), ["message"]), + (ChatMigrated(1234), ["message", "new_chat_id"]), + (RetryAfter(12), ["message", "retry_after"]), + (Conflict("test message"), ["message"]), + (TelegramDecryptionError("test message"), ["message"]) + ], + ) + def test_errors_pickling(self, exception, attributes): + pickled = pickle.dumps(exception) + unpickled = pickle.loads(pickled) + assert type(unpickled) is type(exception) + assert str(unpickled) == str(exception) + + for attribute in attributes: + assert getattr(unpickled, attribute) == getattr(exception, attribute) + + def test_pickling_test_coverage(self): + """ + This test is only here to make sure that new errors will override __reduce__ properly. + Add the new error class to the below covered_subclasses dict, if it's covered in the above + test_errors_pickling test. + """ + def make_assertion(cls): + assert {sc for sc in cls.__subclasses__()} == covered_subclasses[cls] + for subcls in cls.__subclasses__(): + make_assertion(subcls) + + covered_subclasses = defaultdict(set) + covered_subclasses.update({ + TelegramError: {Unauthorized, InvalidToken, NetworkError, ChatMigrated, RetryAfter, + Conflict, TelegramDecryptionError}, + NetworkError: {BadRequest, TimedOut} + }) + + make_assertion(TelegramError)