diff --git a/telegram/files/file.py b/telegram/files/file.py index af723df3e..1caec12b2 100644 --- a/telegram/files/file.py +++ b/telegram/files/file.py @@ -74,32 +74,34 @@ class File(TelegramObject): that object using the ``out.write`` method. Note: - `custom_path` and `out` are mutually exclusive. + :attr:`custom_path` and :attr:`out` are mutually exclusive. Args: custom_path (:obj:`str`, optional): Custom path. - out (:obj:`object`, optional): A file-like object. Must be opened in binary mode, if - applicable. + out (:obj:`io.BufferedWriter`, optional): A file-like object. Must be opened for + writing in binary mode, if applicable. timeout (:obj:`int` | :obj:`float`, optional): If this value is specified, use it as the read timeout from the server (instead of the one specified during creation of the connection pool). + Returns: + :obj:`str` | :obj:`io.BufferedWriter`: The same object as :attr:`out` if specified. + Otherwise, returns the filename downloaded to. + Raises: - ValueError: If both ``custom_path`` and ``out`` are passed. + ValueError: If both :attr:`custom_path` and :attr:`out` are passed. """ if custom_path is not None and out is not None: raise ValueError('custom_path and out are mutually exclusive') # Convert any UTF-8 char into a url encoded ASCII string. - sres = urllib_parse.urlsplit(self.file_path) - url = urllib_parse.urlunsplit(urllib_parse.SplitResult( - sres.scheme, sres.netloc, urllib_parse.quote(sres.path), sres.query, sres.fragment)) + url = self._get_encoded_url() if out: buf = self.bot.request.retrieve(url) out.write(buf) - + return out else: if custom_path: filename = custom_path @@ -107,3 +109,27 @@ class File(TelegramObject): filename = basename(self.file_path) self.bot.request.download(url, filename, timeout=timeout) + return filename + + def _get_encoded_url(self): + """Convert any UTF-8 char in :obj:`File.file_path` into a url encoded ASCII string.""" + sres = urllib_parse.urlsplit(self.file_path) + return urllib_parse.urlunsplit(urllib_parse.SplitResult( + sres.scheme, sres.netloc, urllib_parse.quote(sres.path), sres.query, sres.fragment)) + + def download_as_bytearray(self, buf=None): + """Download this file and return it as a bytearray. + + Args: + buf (:obj:`bytearray`, optional): Extend the given bytearray with the downloaded data. + + Returns: + :obj:`bytearray`: The same object as :attr:`buf` if it was specified. Otherwise a newly + allocated :obj:`bytearray`. + + """ + if buf is None: + buf = bytearray() + + buf.extend(self.bot.request.retrieve(self._get_encoded_url())) + return buf diff --git a/tests/test_file.py b/tests/test_file.py index a01b03e72..6028139b8 100644 --- a/tests/test_file.py +++ b/tests/test_file.py @@ -1,4 +1,5 @@ #!/usr/bin/env python +# -*- coding: utf-8 -*- # # A library that provides a Python interface to the Telegram Bot API # Copyright (C) 2015-2018 @@ -16,6 +17,8 @@ # # You should have received a copy of the GNU Lesser Public License # along with this program. If not, see [http://www.gnu.org/licenses/]. +import os +from tempfile import TemporaryFile, mkstemp import pytest from flaky import flaky @@ -36,6 +39,7 @@ class TestFile(object): file_path = ( u'https://api.org/file/bot133505823:AAHZFMHno3mzVLErU5b5jJvaeG--qUyLyG0/document/file_3') file_size = 28232 + file_content = u'Saint-Saƫns'.encode('utf-8') # Intentionally contains unicode chars. def test_de_json(self, bot): json_dict = { @@ -65,11 +69,61 @@ class TestFile(object): def test_download(self, monkeypatch, file): def test(*args, **kwargs): - raise TelegramError('test worked') + return self.file_content - monkeypatch.setattr('telegram.utils.request.Request.download', test) - with pytest.raises(TelegramError, match='test worked'): - file.download() + monkeypatch.setattr('telegram.utils.request.Request.retrieve', test) + out_file = file.download() + + try: + with open(out_file, 'rb') as fobj: + assert fobj.read() == self.file_content + finally: + os.unlink(out_file) + + def test_download_custom_path(self, monkeypatch, file): + def test(*args, **kwargs): + return self.file_content + + monkeypatch.setattr('telegram.utils.request.Request.retrieve', test) + file_handle, custom_path = mkstemp() + try: + out_file = file.download(custom_path) + assert out_file == custom_path + + with open(out_file, 'rb') as fobj: + assert fobj.read() == self.file_content + finally: + os.close(file_handle) + os.unlink(custom_path) + + def test_download_file_obj(self, monkeypatch, file): + def test(*args, **kwargs): + return self.file_content + + monkeypatch.setattr('telegram.utils.request.Request.retrieve', test) + with TemporaryFile() as custom_fobj: + out_fobj = file.download(out=custom_fobj) + assert out_fobj is custom_fobj + + out_fobj.seek(0) + assert out_fobj.read() == self.file_content + + def test_download_bytearray(self, monkeypatch, file): + def test(*args, **kwargs): + return self.file_content + + monkeypatch.setattr('telegram.utils.request.Request.retrieve', test) + + # Check that a download to a newly allocated bytearray works. + buf = file.download_as_bytearray() + assert buf == bytearray(self.file_content) + + # Check that a download to a given bytearray works (extends the bytearray). + buf2 = buf[:] + buf3 = file.download_as_bytearray(buf=buf2) + assert buf3 is buf2 + assert buf2[len(buf):] == buf + assert buf2[:len(buf)] == buf def test_equality(self, bot): a = File(self.file_id, bot)