Drop Manual Token Validation (#3167)

This commit is contained in:
Harshil 2022-08-03 11:46:48 +05:30 committed by GitHub
parent c28ad86214
commit 143db5fc9d
No known key found for this signature in database
GPG key ID: 4AEE18F83AFDEB23
5 changed files with 28 additions and 51 deletions

View file

@ -197,7 +197,9 @@ class Bot(TelegramObject, AbstractAsyncContextManager):
private_key: bytes = None, private_key: bytes = None,
private_key_password: bytes = None, private_key_password: bytes = None,
): ):
self.token = self._validate_token(token) if not token:
raise InvalidToken("You must pass the token you received from https://t.me/Botfather!")
self.token = token
self.base_url = base_url + self.token self.base_url = base_url + self.token
self.base_file_url = base_file_url + self.token self.base_file_url = base_file_url + self.token
@ -372,7 +374,12 @@ class Bot(TelegramObject, AbstractAsyncContextManager):
return return
await asyncio.gather(self._request[0].initialize(), self._request[1].initialize()) await asyncio.gather(self._request[0].initialize(), self._request[1].initialize())
await self.get_me() # Since the bot is to be initialized only once, we can also use it for
# verifying the token passed and raising an exception if it's invalid.
try:
await self.get_me()
except InvalidToken as exc:
raise InvalidToken(f"The token `{self.token}` was rejected by the server.") from exc
self._initialized = True self._initialized = True
async def shutdown(self) -> None: async def shutdown(self) -> None:
@ -418,18 +425,6 @@ class Bot(TelegramObject, AbstractAsyncContextManager):
""" """
return self._request[1] return self._request[1]
@staticmethod
def _validate_token(token: str) -> str:
"""A very basic validation on token."""
if any(x.isspace() for x in token):
raise InvalidToken()
left, sep, _right = token.partition(":")
if (not sep) or (not left.isdigit()) or (len(left) < 3):
raise InvalidToken()
return token
@property @property
def bot(self) -> User: def bot(self) -> User:
""":class:`telegram.User`: User instance for the bot as returned by :meth:`get_me`. """:class:`telegram.User`: User instance for the bot as returned by :meth:`get_me`.
@ -2857,7 +2852,7 @@ class Bot(TelegramObject, AbstractAsyncContextManager):
connect_timeout: ODVInput[float] = DEFAULT_NONE, connect_timeout: ODVInput[float] = DEFAULT_NONE,
pool_timeout: ODVInput[float] = DEFAULT_NONE, pool_timeout: ODVInput[float] = DEFAULT_NONE,
api_kwargs: JSONDict = None, api_kwargs: JSONDict = None,
) -> Optional[UserProfilePhotos]: ) -> UserProfilePhotos:
"""Use this method to get a list of profile pictures for a user. """Use this method to get a list of profile pictures for a user.
Args: Args:
@ -2907,7 +2902,7 @@ class Bot(TelegramObject, AbstractAsyncContextManager):
api_kwargs=api_kwargs, api_kwargs=api_kwargs,
) )
return UserProfilePhotos.de_json(result, self) # type: ignore[arg-type] return UserProfilePhotos.de_json(result, self) # type: ignore[return-value, arg-type]
@_log @_log
async def get_file( async def get_file(

View file

@ -35,7 +35,7 @@ __all__ = (
"TimedOut", "TimedOut",
) )
from typing import Optional, Tuple, Union from typing import Tuple, Union
def _lstrip_str(in_s: str, lstr: str) -> str: def _lstrip_str(in_s: str, lstr: str) -> str:
@ -100,14 +100,10 @@ class InvalidToken(TelegramError):
.. versionadded:: 20.0 .. versionadded:: 20.0
""" """
__slots__ = ("_message",) __slots__ = ()
def __init__(self, message: str = None) -> None: def __init__(self, message: str = None) -> None:
self._message = message super().__init__("Invalid token" if message is None else message)
super().__init__("Invalid token" if self._message is None else self._message)
def __reduce__(self) -> Tuple[type, Tuple[Optional[str]]]: # type: ignore[override]
return self.__class__, (self._message,)
class NetworkError(TelegramError): class NetworkError(TelegramError):

View file

@ -312,20 +312,20 @@ class BaseRequest(
message += f"\nThe server response contained unknown parameters: {parameters}" message += f"\nThe server response contained unknown parameters: {parameters}"
if code == HTTPStatus.FORBIDDEN: if code == HTTPStatus.FORBIDDEN: # 403
raise Forbidden(message) raise Forbidden(message)
if code in (HTTPStatus.NOT_FOUND, HTTPStatus.UNAUTHORIZED): if code in (HTTPStatus.NOT_FOUND, HTTPStatus.UNAUTHORIZED): # 404 and 401
# TG returns 404 Not found for # TG returns 404 Not found for
# 1) malformed tokens # 1) malformed tokens
# 2) correct tokens but non-existing method, e.g. api.tg.org/botTOKEN/unkonwnMethod # 2) correct tokens but non-existing method, e.g. api.tg.org/botTOKEN/unkonwnMethod
# We can basically rule out 2) since we don't let users make requests manually # We can basically rule out 2) since we don't let users make requests manually
# TG returns 401 Unauthorized for correctly formatted tokens that are not valid # TG returns 401 Unauthorized for correctly formatted tokens that are not valid
raise InvalidToken(message) raise InvalidToken(message)
if code == HTTPStatus.BAD_REQUEST: if code == HTTPStatus.BAD_REQUEST: # 400
raise BadRequest(message) raise BadRequest(message)
if code == HTTPStatus.CONFLICT: if code == HTTPStatus.CONFLICT: # 409
raise Conflict(message) raise Conflict(message)
if code == HTTPStatus.BAD_GATEWAY: if code == HTTPStatus.BAD_GATEWAY: # 502
raise NetworkError(description or "Bad Gateway") raise NetworkError(description or "Bad Gateway")
raise NetworkError(f"{message} ({code})") raise NetworkError(f"{message} ({code})")

View file

@ -183,22 +183,6 @@ class TestBot:
assert getattr(inst, attr, "err") != "err", f"got extra slot '{attr}'" assert getattr(inst, attr, "err") != "err", f"got extra slot '{attr}'"
assert len(mro_slots(inst)) == len(set(mro_slots(inst))), "duplicate slot" assert len(mro_slots(inst)) == len(set(mro_slots(inst))), "duplicate slot"
@pytest.mark.parametrize(
"token",
argvalues=[
"123",
"12a:abcd1234",
"12:abcd1234",
"1234:abcd1234\n",
" 1234:abcd1234",
" 1234:abcd1234\r",
"1234:abcd 1234",
],
)
async def test_invalid_token(self, token):
with pytest.raises(InvalidToken, match="Invalid token"):
Bot(token)
async def test_initialize_and_shutdown(self, bot, monkeypatch): async def test_initialize_and_shutdown(self, bot, monkeypatch):
async def initialize(*args, **kwargs): async def initialize(*args, **kwargs):
self.test_flag = ["initialize"] self.test_flag = ["initialize"]
@ -307,12 +291,14 @@ class TestBot:
assert acd_bot.arbitrary_callback_data == acd assert acd_bot.arbitrary_callback_data == acd
assert acd_bot.callback_data_cache.maxsize == maxsize assert acd_bot.callback_data_cache.maxsize == maxsize
@flaky(3, 1) async def test_no_token_passed(self):
async def test_invalid_token_server_response(self, monkeypatch): with pytest.raises(InvalidToken, match="You must pass the token"):
monkeypatch.setattr("telegram.Bot._validate_token", lambda x, y: "") Bot("")
with pytest.raises(InvalidToken):
async with make_bot(token="12") as bot: async def test_invalid_token_server_response(self):
await bot.get_me() with pytest.raises(InvalidToken, match="The token `12` was rejected by the server."):
async with make_bot(token="12"):
pass
async def test_unknown_kwargs(self, bot, monkeypatch): async def test_unknown_kwargs(self, bot, monkeypatch):
async def post(url, request_data: RequestData, *args, **kwargs): async def post(url, request_data: RequestData, *args, **kwargs):

View file

@ -240,7 +240,7 @@ class TestRequest:
) )
with pytest.raises(exception_class, match="Test Message"): with pytest.raises(exception_class, match="Test Message"):
await httpx_request.post(None, None, None) await httpx_request.post("", None, None)
@pytest.mark.parametrize( @pytest.mark.parametrize(
["exception", "catch_class", "match"], ["exception", "catch_class", "match"],