mirror of
https://github.com/teloxide/teloxide.git
synced 2024-12-22 22:46:39 +01:00
Allow sending impl AsyncRead
This commit adds `InputFile::read` constructor that creates `InputFile` from an `impl AsyncRead + Send + Unpin + 'static`. Internally this requires quite a bit of work, since we need to support cloning `InputFile`s but the `AsyncRead` trait only allows us reading it once. To support this, if `InputFile` detects that it's shared, it reads the contents of the `AsyncRead` into a buffer and then shares the buffer (or an error if it has occured).
This commit is contained in:
parent
a84e897db9
commit
5b4ed3faa9
2 changed files with 176 additions and 2 deletions
|
@ -39,7 +39,9 @@ derive_more = "0.99.9"
|
|||
mime = "0.3.16"
|
||||
thiserror = "1.0.20"
|
||||
once_cell = "1.5.0"
|
||||
takecell = "0.1"
|
||||
take_mut = "0.2"
|
||||
rc-box = "1.1.1"
|
||||
never = "0.1.0"
|
||||
chrono = { version = "0.4.19", default-features = false }
|
||||
either = "1.6.1"
|
||||
|
|
|
@ -4,11 +4,27 @@ use futures::{
|
|||
stream,
|
||||
};
|
||||
use once_cell::sync::OnceCell;
|
||||
use rc_box::ArcBox;
|
||||
use reqwest::{multipart::Part, Body};
|
||||
use serde::Serialize;
|
||||
use takecell::TakeCell;
|
||||
use tokio::{
|
||||
io::{AsyncRead, AsyncReadExt, ReadBuf},
|
||||
sync::watch,
|
||||
};
|
||||
use tokio_util::codec::{Decoder, FramedRead};
|
||||
|
||||
use std::{borrow::Cow, fmt, future::Future, io, mem, path::PathBuf, sync::Arc};
|
||||
use std::{
|
||||
borrow::Cow,
|
||||
convert::{Infallible, TryFrom},
|
||||
fmt,
|
||||
future::Future,
|
||||
io, iter, mem,
|
||||
path::PathBuf,
|
||||
pin::Pin,
|
||||
sync::Arc,
|
||||
task,
|
||||
};
|
||||
|
||||
use crate::types::InputSticker;
|
||||
|
||||
|
@ -24,6 +40,7 @@ pub struct InputFile {
|
|||
|
||||
#[derive(Clone)]
|
||||
enum InnerFile {
|
||||
Read(Read),
|
||||
File(PathBuf),
|
||||
Bytes(bytes::Bytes),
|
||||
Url(url::Url),
|
||||
|
@ -87,6 +104,14 @@ impl InputFile {
|
|||
self
|
||||
}
|
||||
|
||||
/// Creates an `InputFile` from a in-memory bytes.
|
||||
///
|
||||
/// Note: in some cases (e.g. sending the same `InputFile` multiple times)
|
||||
/// this may read the whole `impl AsyncRead` into memory.
|
||||
pub fn read(it: impl AsyncRead + Send + Unpin + 'static) -> Self {
|
||||
Self::new(Read(Read::new(Arc::new(TakeCell::new(it)))))
|
||||
}
|
||||
|
||||
/// Shorthand for `Self { file_name: None, inner, id: default() }`
|
||||
/// (private because `InnerFile` iы private implementation detail)
|
||||
fn new(inner: InnerFile) -> Self {
|
||||
|
@ -157,6 +182,7 @@ impl InputFile {
|
|||
impl fmt::Debug for InnerFile {
|
||||
fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
|
||||
match self {
|
||||
Read(_) => f.debug_struct("Read").finish_non_exhaustive(),
|
||||
File(path) => f.debug_struct("File").field("path", path).finish(),
|
||||
Bytes(bytes) if f.alternate() => f.debug_tuple("Memory").field(bytes).finish(),
|
||||
Bytes(_) => f.debug_struct("Memory").finish_non_exhaustive(),
|
||||
|
@ -207,14 +233,160 @@ impl InputFile {
|
|||
}
|
||||
Bytes(data) => {
|
||||
let stream = Part::stream(data).file_name(filename);
|
||||
Some(Either::Right(ready(stream)))
|
||||
Some(Either::Right(Either::Left(ready(stream))))
|
||||
}
|
||||
Read(read) => Some(Either::Right(Either::Right(read.into_part(filename)))),
|
||||
};
|
||||
|
||||
file_part
|
||||
}
|
||||
}
|
||||
|
||||
/// Adaptor for `AsyncRead` that allows clonning and converting to
|
||||
/// `multipart/form-data`
|
||||
#[derive(Clone)]
|
||||
struct Read {
|
||||
inner: Arc<TakeCell<dyn AsyncRead + Send + Unpin>>,
|
||||
buf: Arc<OnceCell<Result<Vec<Bytes>, Arc<io::Error>>>>,
|
||||
notify: Arc<watch::Sender<()>>,
|
||||
wait: watch::Receiver<()>,
|
||||
}
|
||||
|
||||
impl Read {
|
||||
fn new(it: Arc<TakeCell<dyn AsyncRead + Send + Unpin>>) -> Self {
|
||||
let (tx, rx) = watch::channel(());
|
||||
|
||||
Self {
|
||||
inner: it,
|
||||
buf: Arc::default(),
|
||||
notify: Arc::new(tx),
|
||||
wait: rx,
|
||||
}
|
||||
}
|
||||
|
||||
pub(crate) async fn into_part(mut self, filename: Cow<'static, str>) -> Part {
|
||||
if !self.inner.is_taken() {
|
||||
let res = ArcBox::<TakeCell<dyn AsyncRead + Send + Unpin>>::try_from(self.inner);
|
||||
match res {
|
||||
// Fast/easy path: this is the only file copy, so we can just forward the underlying
|
||||
// `dyn AsynсRead` via some adaptors to reqwest.
|
||||
Ok(arc_box) => {
|
||||
let fr = FramedRead::new(ExclusiveArcAsyncRead(arc_box), BytesDecoder);
|
||||
|
||||
let body = Body::wrap_stream(fr);
|
||||
return Part::stream(body).file_name(filename);
|
||||
}
|
||||
// move the arc back into `self`
|
||||
Err(i) => self.inner = i,
|
||||
}
|
||||
}
|
||||
|
||||
// Slow path: either wait until someone will read the whole `dyn AsynсRead` into
|
||||
// a buffer, or be the one who reads
|
||||
let body = self.into_shared_body().await;
|
||||
|
||||
Part::stream(body).file_name(filename)
|
||||
}
|
||||
|
||||
async fn into_shared_body(mut self) -> Body {
|
||||
match self.inner.take() {
|
||||
// Read `dyn AsyncRead` into a buffer
|
||||
Some(mut read_ref) => {
|
||||
// Chunk size, arbitrary chosen to be 1KiB
|
||||
const CHUNK: usize = 1024;
|
||||
|
||||
let mut chunks = Vec::new();
|
||||
let mut bytes = BytesMut::with_capacity(CHUNK);
|
||||
|
||||
let res = loop {
|
||||
match (&mut read_ref).read_buf(&mut bytes).await {
|
||||
// eof
|
||||
Ok(0) if bytes.len() < bytes.capacity() => {
|
||||
chunks.push(bytes.freeze());
|
||||
|
||||
break Ok(chunks);
|
||||
}
|
||||
|
||||
// No space left in bytes, allocate a new chunk
|
||||
Ok(0) => {
|
||||
chunks.push(bytes.freeze());
|
||||
bytes = BytesMut::with_capacity(CHUNK);
|
||||
}
|
||||
|
||||
// keep reading into the same chunk
|
||||
Ok(_) => {}
|
||||
|
||||
// i/o error
|
||||
Err(err) => break Err(Arc::new(err)),
|
||||
}
|
||||
};
|
||||
|
||||
// Initialize `buf` with the result.
|
||||
// Error indicates that the `buf` was already initialized, but this can't happen
|
||||
// since we synchronize through other means.
|
||||
let r = self.buf.set(res);
|
||||
debug_assert!(r.is_ok());
|
||||
|
||||
// Notify other tasks that `buf` is initialized.
|
||||
// Error indicates that there is no one to notify anymore, but we don't care.
|
||||
let _ = self.notify.send(());
|
||||
}
|
||||
|
||||
// Wait until `dyn AsynсRead` is read into a buffer, if it hasn't been read yet
|
||||
None if self.buf.get().is_none() => {
|
||||
// Error indicates that the sender was dropped, by we hold `Arc<Sender>`, so
|
||||
// this can't happen
|
||||
let _ = self.wait.changed().await;
|
||||
}
|
||||
|
||||
// Someone else has already initialized the buffer
|
||||
None => {}
|
||||
};
|
||||
|
||||
let buf = self.buf;
|
||||
// unwrap: `OnceCell` is initialized in the match above before sending
|
||||
// notification, so at this point it's already initialized.
|
||||
match buf.get().unwrap() {
|
||||
Ok(_) => {
|
||||
// We can't use `.iter()` here, because the iterator must capture `buf`
|
||||
let mut i = 0;
|
||||
let iter = iter::from_fn(move || match buf.get().unwrap() {
|
||||
Ok(buf) if i >= buf.len() => None,
|
||||
Ok(buf) => {
|
||||
let res = buf[i].clone();
|
||||
i += 1;
|
||||
Some(Ok::<_, Infallible>(res))
|
||||
}
|
||||
// We've just checked in the above match, it's `Ok(_)`
|
||||
Err(_) => unreachable!(),
|
||||
});
|
||||
|
||||
Body::wrap_stream(stream::iter(iter))
|
||||
}
|
||||
|
||||
Err(err) => {
|
||||
let err = Err::<Bytes, _>(Arc::clone(err));
|
||||
Body::wrap_stream(stream::iter(iter::once(err)))
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
/// Wrapper over an `ArcBox` that implements `AsyncRead`.
|
||||
struct ExclusiveArcAsyncRead(ArcBox<TakeCell<dyn AsyncRead + Send + Unpin>>);
|
||||
|
||||
impl AsyncRead for ExclusiveArcAsyncRead {
|
||||
fn poll_read(
|
||||
self: Pin<&mut Self>,
|
||||
cx: &mut task::Context<'_>,
|
||||
buf: &mut ReadBuf<'_>,
|
||||
) -> task::Poll<io::Result<()>> {
|
||||
let Self(inner) = Pin::get_mut(self);
|
||||
let read: &mut (dyn AsyncRead + Unpin) = inner.get();
|
||||
Pin::new(read).poll_read(cx, buf)
|
||||
}
|
||||
}
|
||||
|
||||
struct BytesDecoder;
|
||||
|
||||
impl Decoder for BytesDecoder {
|
||||
|
|
Loading…
Reference in a new issue