mirror of
https://github.com/python-telegram-bot/python-telegram-bot.git
synced 2025-01-03 09:49:21 +01:00
Drop Manual Token Validation (#3167)
This commit is contained in:
parent
c28ad86214
commit
143db5fc9d
5 changed files with 28 additions and 51 deletions
|
@ -197,7 +197,9 @@ class Bot(TelegramObject, AbstractAsyncContextManager):
|
||||||
private_key: bytes = None,
|
private_key: bytes = None,
|
||||||
private_key_password: bytes = None,
|
private_key_password: bytes = None,
|
||||||
):
|
):
|
||||||
self.token = self._validate_token(token)
|
if not token:
|
||||||
|
raise InvalidToken("You must pass the token you received from https://t.me/Botfather!")
|
||||||
|
self.token = token
|
||||||
|
|
||||||
self.base_url = base_url + self.token
|
self.base_url = base_url + self.token
|
||||||
self.base_file_url = base_file_url + self.token
|
self.base_file_url = base_file_url + self.token
|
||||||
|
@ -372,7 +374,12 @@ class Bot(TelegramObject, AbstractAsyncContextManager):
|
||||||
return
|
return
|
||||||
|
|
||||||
await asyncio.gather(self._request[0].initialize(), self._request[1].initialize())
|
await asyncio.gather(self._request[0].initialize(), self._request[1].initialize())
|
||||||
await self.get_me()
|
# Since the bot is to be initialized only once, we can also use it for
|
||||||
|
# verifying the token passed and raising an exception if it's invalid.
|
||||||
|
try:
|
||||||
|
await self.get_me()
|
||||||
|
except InvalidToken as exc:
|
||||||
|
raise InvalidToken(f"The token `{self.token}` was rejected by the server.") from exc
|
||||||
self._initialized = True
|
self._initialized = True
|
||||||
|
|
||||||
async def shutdown(self) -> None:
|
async def shutdown(self) -> None:
|
||||||
|
@ -418,18 +425,6 @@ class Bot(TelegramObject, AbstractAsyncContextManager):
|
||||||
"""
|
"""
|
||||||
return self._request[1]
|
return self._request[1]
|
||||||
|
|
||||||
@staticmethod
|
|
||||||
def _validate_token(token: str) -> str:
|
|
||||||
"""A very basic validation on token."""
|
|
||||||
if any(x.isspace() for x in token):
|
|
||||||
raise InvalidToken()
|
|
||||||
|
|
||||||
left, sep, _right = token.partition(":")
|
|
||||||
if (not sep) or (not left.isdigit()) or (len(left) < 3):
|
|
||||||
raise InvalidToken()
|
|
||||||
|
|
||||||
return token
|
|
||||||
|
|
||||||
@property
|
@property
|
||||||
def bot(self) -> User:
|
def bot(self) -> User:
|
||||||
""":class:`telegram.User`: User instance for the bot as returned by :meth:`get_me`.
|
""":class:`telegram.User`: User instance for the bot as returned by :meth:`get_me`.
|
||||||
|
@ -2857,7 +2852,7 @@ class Bot(TelegramObject, AbstractAsyncContextManager):
|
||||||
connect_timeout: ODVInput[float] = DEFAULT_NONE,
|
connect_timeout: ODVInput[float] = DEFAULT_NONE,
|
||||||
pool_timeout: ODVInput[float] = DEFAULT_NONE,
|
pool_timeout: ODVInput[float] = DEFAULT_NONE,
|
||||||
api_kwargs: JSONDict = None,
|
api_kwargs: JSONDict = None,
|
||||||
) -> Optional[UserProfilePhotos]:
|
) -> UserProfilePhotos:
|
||||||
"""Use this method to get a list of profile pictures for a user.
|
"""Use this method to get a list of profile pictures for a user.
|
||||||
|
|
||||||
Args:
|
Args:
|
||||||
|
@ -2907,7 +2902,7 @@ class Bot(TelegramObject, AbstractAsyncContextManager):
|
||||||
api_kwargs=api_kwargs,
|
api_kwargs=api_kwargs,
|
||||||
)
|
)
|
||||||
|
|
||||||
return UserProfilePhotos.de_json(result, self) # type: ignore[arg-type]
|
return UserProfilePhotos.de_json(result, self) # type: ignore[return-value, arg-type]
|
||||||
|
|
||||||
@_log
|
@_log
|
||||||
async def get_file(
|
async def get_file(
|
||||||
|
|
|
@ -35,7 +35,7 @@ __all__ = (
|
||||||
"TimedOut",
|
"TimedOut",
|
||||||
)
|
)
|
||||||
|
|
||||||
from typing import Optional, Tuple, Union
|
from typing import Tuple, Union
|
||||||
|
|
||||||
|
|
||||||
def _lstrip_str(in_s: str, lstr: str) -> str:
|
def _lstrip_str(in_s: str, lstr: str) -> str:
|
||||||
|
@ -100,14 +100,10 @@ class InvalidToken(TelegramError):
|
||||||
.. versionadded:: 20.0
|
.. versionadded:: 20.0
|
||||||
"""
|
"""
|
||||||
|
|
||||||
__slots__ = ("_message",)
|
__slots__ = ()
|
||||||
|
|
||||||
def __init__(self, message: str = None) -> None:
|
def __init__(self, message: str = None) -> None:
|
||||||
self._message = message
|
super().__init__("Invalid token" if message is None else message)
|
||||||
super().__init__("Invalid token" if self._message is None else self._message)
|
|
||||||
|
|
||||||
def __reduce__(self) -> Tuple[type, Tuple[Optional[str]]]: # type: ignore[override]
|
|
||||||
return self.__class__, (self._message,)
|
|
||||||
|
|
||||||
|
|
||||||
class NetworkError(TelegramError):
|
class NetworkError(TelegramError):
|
||||||
|
|
|
@ -312,20 +312,20 @@ class BaseRequest(
|
||||||
|
|
||||||
message += f"\nThe server response contained unknown parameters: {parameters}"
|
message += f"\nThe server response contained unknown parameters: {parameters}"
|
||||||
|
|
||||||
if code == HTTPStatus.FORBIDDEN:
|
if code == HTTPStatus.FORBIDDEN: # 403
|
||||||
raise Forbidden(message)
|
raise Forbidden(message)
|
||||||
if code in (HTTPStatus.NOT_FOUND, HTTPStatus.UNAUTHORIZED):
|
if code in (HTTPStatus.NOT_FOUND, HTTPStatus.UNAUTHORIZED): # 404 and 401
|
||||||
# TG returns 404 Not found for
|
# TG returns 404 Not found for
|
||||||
# 1) malformed tokens
|
# 1) malformed tokens
|
||||||
# 2) correct tokens but non-existing method, e.g. api.tg.org/botTOKEN/unkonwnMethod
|
# 2) correct tokens but non-existing method, e.g. api.tg.org/botTOKEN/unkonwnMethod
|
||||||
# We can basically rule out 2) since we don't let users make requests manually
|
# We can basically rule out 2) since we don't let users make requests manually
|
||||||
# TG returns 401 Unauthorized for correctly formatted tokens that are not valid
|
# TG returns 401 Unauthorized for correctly formatted tokens that are not valid
|
||||||
raise InvalidToken(message)
|
raise InvalidToken(message)
|
||||||
if code == HTTPStatus.BAD_REQUEST:
|
if code == HTTPStatus.BAD_REQUEST: # 400
|
||||||
raise BadRequest(message)
|
raise BadRequest(message)
|
||||||
if code == HTTPStatus.CONFLICT:
|
if code == HTTPStatus.CONFLICT: # 409
|
||||||
raise Conflict(message)
|
raise Conflict(message)
|
||||||
if code == HTTPStatus.BAD_GATEWAY:
|
if code == HTTPStatus.BAD_GATEWAY: # 502
|
||||||
raise NetworkError(description or "Bad Gateway")
|
raise NetworkError(description or "Bad Gateway")
|
||||||
raise NetworkError(f"{message} ({code})")
|
raise NetworkError(f"{message} ({code})")
|
||||||
|
|
||||||
|
|
|
@ -183,22 +183,6 @@ class TestBot:
|
||||||
assert getattr(inst, attr, "err") != "err", f"got extra slot '{attr}'"
|
assert getattr(inst, attr, "err") != "err", f"got extra slot '{attr}'"
|
||||||
assert len(mro_slots(inst)) == len(set(mro_slots(inst))), "duplicate slot"
|
assert len(mro_slots(inst)) == len(set(mro_slots(inst))), "duplicate slot"
|
||||||
|
|
||||||
@pytest.mark.parametrize(
|
|
||||||
"token",
|
|
||||||
argvalues=[
|
|
||||||
"123",
|
|
||||||
"12a:abcd1234",
|
|
||||||
"12:abcd1234",
|
|
||||||
"1234:abcd1234\n",
|
|
||||||
" 1234:abcd1234",
|
|
||||||
" 1234:abcd1234\r",
|
|
||||||
"1234:abcd 1234",
|
|
||||||
],
|
|
||||||
)
|
|
||||||
async def test_invalid_token(self, token):
|
|
||||||
with pytest.raises(InvalidToken, match="Invalid token"):
|
|
||||||
Bot(token)
|
|
||||||
|
|
||||||
async def test_initialize_and_shutdown(self, bot, monkeypatch):
|
async def test_initialize_and_shutdown(self, bot, monkeypatch):
|
||||||
async def initialize(*args, **kwargs):
|
async def initialize(*args, **kwargs):
|
||||||
self.test_flag = ["initialize"]
|
self.test_flag = ["initialize"]
|
||||||
|
@ -307,12 +291,14 @@ class TestBot:
|
||||||
assert acd_bot.arbitrary_callback_data == acd
|
assert acd_bot.arbitrary_callback_data == acd
|
||||||
assert acd_bot.callback_data_cache.maxsize == maxsize
|
assert acd_bot.callback_data_cache.maxsize == maxsize
|
||||||
|
|
||||||
@flaky(3, 1)
|
async def test_no_token_passed(self):
|
||||||
async def test_invalid_token_server_response(self, monkeypatch):
|
with pytest.raises(InvalidToken, match="You must pass the token"):
|
||||||
monkeypatch.setattr("telegram.Bot._validate_token", lambda x, y: "")
|
Bot("")
|
||||||
with pytest.raises(InvalidToken):
|
|
||||||
async with make_bot(token="12") as bot:
|
async def test_invalid_token_server_response(self):
|
||||||
await bot.get_me()
|
with pytest.raises(InvalidToken, match="The token `12` was rejected by the server."):
|
||||||
|
async with make_bot(token="12"):
|
||||||
|
pass
|
||||||
|
|
||||||
async def test_unknown_kwargs(self, bot, monkeypatch):
|
async def test_unknown_kwargs(self, bot, monkeypatch):
|
||||||
async def post(url, request_data: RequestData, *args, **kwargs):
|
async def post(url, request_data: RequestData, *args, **kwargs):
|
||||||
|
|
|
@ -240,7 +240,7 @@ class TestRequest:
|
||||||
)
|
)
|
||||||
|
|
||||||
with pytest.raises(exception_class, match="Test Message"):
|
with pytest.raises(exception_class, match="Test Message"):
|
||||||
await httpx_request.post(None, None, None)
|
await httpx_request.post("", None, None)
|
||||||
|
|
||||||
@pytest.mark.parametrize(
|
@pytest.mark.parametrize(
|
||||||
["exception", "catch_class", "match"],
|
["exception", "catch_class", "match"],
|
||||||
|
|
Loading…
Reference in a new issue