Allow sending impl AsyncRead

This commit adds `InputFile::read` constructor that creates `InputFile`
from an `impl AsyncRead + Send + Unpin + 'static`.

Internally this requires quite a bit of work, since we need to support
cloning `InputFile`s, but the `AsyncRead` trait only allows us to read it
once.

To support this, if `InputFile` detects that it's shared, it reads the
contents of the `AsyncRead` into a buffer and then shares the buffer
(or the error, if one has occurred).
This commit is contained in:
Maybe Waffle 2022-01-13 16:46:37 +03:00
parent a84e897db9
commit 5b4ed3faa9
2 changed files with 176 additions and 2 deletions

View file

@ -39,7 +39,9 @@ derive_more = "0.99.9"
mime = "0.3.16"
thiserror = "1.0.20"
once_cell = "1.5.0"
takecell = "0.1"
take_mut = "0.2"
rc-box = "1.1.1"
never = "0.1.0"
chrono = { version = "0.4.19", default-features = false }
either = "1.6.1"

View file

@ -4,11 +4,27 @@ use futures::{
stream,
};
use once_cell::sync::OnceCell;
use rc_box::ArcBox;
use reqwest::{multipart::Part, Body};
use serde::Serialize;
use takecell::TakeCell;
use tokio::{
io::{AsyncRead, AsyncReadExt, ReadBuf},
sync::watch,
};
use tokio_util::codec::{Decoder, FramedRead};
use std::{borrow::Cow, fmt, future::Future, io, mem, path::PathBuf, sync::Arc};
use std::{
borrow::Cow,
convert::{Infallible, TryFrom},
fmt,
future::Future,
io, iter, mem,
path::PathBuf,
pin::Pin,
sync::Arc,
task,
};
use crate::types::InputSticker;
@ -24,6 +40,7 @@ pub struct InputFile {
#[derive(Clone)]
enum InnerFile {
Read(Read),
File(PathBuf),
Bytes(bytes::Bytes),
Url(url::Url),
@ -87,6 +104,14 @@ impl InputFile {
self
}
/// Creates an `InputFile` from a in-memory bytes.
///
/// Note: in some cases (e.g. sending the same `InputFile` multiple times)
/// this may read the whole `impl AsyncRead` into memory.
pub fn read(it: impl AsyncRead + Send + Unpin + 'static) -> Self {
Self::new(Read(Read::new(Arc::new(TakeCell::new(it)))))
}
/// Shorthand for `Self { file_name: None, inner, id: default() }`
/// (private because `InnerFile` is a private implementation detail)
fn new(inner: InnerFile) -> Self {
@ -157,6 +182,7 @@ impl InputFile {
impl fmt::Debug for InnerFile {
fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
match self {
Read(_) => f.debug_struct("Read").finish_non_exhaustive(),
File(path) => f.debug_struct("File").field("path", path).finish(),
Bytes(bytes) if f.alternate() => f.debug_tuple("Memory").field(bytes).finish(),
Bytes(_) => f.debug_struct("Memory").finish_non_exhaustive(),
@ -207,14 +233,160 @@ impl InputFile {
}
Bytes(data) => {
let stream = Part::stream(data).file_name(filename);
Some(Either::Right(ready(stream)))
Some(Either::Right(Either::Left(ready(stream))))
}
Read(read) => Some(Either::Right(Either::Right(read.into_part(filename)))),
};
file_part
}
}
/// Adaptor for `AsyncRead` that allows cloning and converting to
/// `multipart/form-data`
///
/// All clones share the same reader cell and the same result buffer, so the
/// underlying reader is consumed at most once no matter how many clones are
/// converted.
#[derive(Clone)]
struct Read {
/// The underlying reader. Either unwrapped exclusively (when this is the only
/// copy) or taken by the one clone that reads it into `buf`.
inner: Arc<TakeCell<dyn AsyncRead + Send + Unpin>>,
/// Set once with the chunks read from `inner`, or with the I/O error that
/// interrupted reading.
buf: Arc<OnceCell<Result<Vec<Bytes>, Arc<io::Error>>>>,
/// Used by the clone that filled `buf` to signal the other clones.
notify: Arc<watch::Sender<()>>,
/// Awaited by clones that found `inner` already taken while `buf` is still unset.
wait: watch::Receiver<()>,
}
impl Read {
/// Wraps the shared reader cell together with a fresh result buffer and a
/// `watch` channel used to signal when that buffer has been set.
fn new(it: Arc<TakeCell<dyn AsyncRead + Send + Unpin>>) -> Self {
let (tx, rx) = watch::channel(());
Self {
inner: it,
// Starts out unset; `into_shared_body` sets it at most once.
buf: Arc::default(),
notify: Arc::new(tx),
wait: rx,
}
}
/// Converts this reader into a multipart `Part` named `filename`.
///
/// If the reader has not been taken yet and this is the only copy, the reader is
/// streamed directly. Otherwise the body comes from the shared buffer, see
/// `into_shared_body`.
pub(crate) async fn into_part(mut self, filename: Cow<'static, str>) -> Part {
if !self.inner.is_taken() {
// Succeeds only if the `Arc` can be turned into a uniquely owned `ArcBox`.
let res = ArcBox::<TakeCell<dyn AsyncRead + Send + Unpin>>::try_from(self.inner);
match res {
// Fast/easy path: this is the only file copy, so we can just forward the underlying
// `dyn AsyncRead` via some adaptors to reqwest.
Ok(arc_box) => {
let fr = FramedRead::new(ExclusiveArcAsyncRead(arc_box), BytesDecoder);
let body = Body::wrap_stream(fr);
return Part::stream(body).file_name(filename);
}
// move the arc back into `self`
Err(i) => self.inner = i,
}
}
// Slow path: either wait until someone will read the whole `dyn AsyncRead` into
// a buffer, or be the one who reads
let body = self.into_shared_body().await;
Part::stream(body).file_name(filename)
}
/// Produces a `Body` backed by the shared buffer.
///
/// The clone that manages to `take()` the reader reads it to the end, stores the
/// result in `buf` and notifies the others; every other clone waits for that
/// notification unless `buf` is already set. The returned body then yields either
/// the buffered chunks or the stored I/O error.
async fn into_shared_body(mut self) -> Body {
match self.inner.take() {
// Read `dyn AsyncRead` into a buffer
// NOTE(review): if this future is dropped after `take()` but before `buf.set(..)`
// below, it looks like waiting clones would never be notified — confirm.
Some(mut read_ref) => {
// Chunk size, arbitrarily chosen to be 1KiB
const CHUNK: usize = 1024;
let mut chunks = Vec::new();
let mut bytes = BytesMut::with_capacity(CHUNK);
let res = loop {
match (&mut read_ref).read_buf(&mut bytes).await {
// eof
Ok(0) if bytes.len() < bytes.capacity() => {
chunks.push(bytes.freeze());
break Ok(chunks);
}
// No space left in bytes, allocate a new chunk
// NOTE(review): `read_buf` may grow a `BytesMut` instead of reporting a full
// buffer, in which case this arm would be unreachable and the whole input
// would end up in a single chunk — confirm against the `bytes`/`tokio` docs.
Ok(0) => {
chunks.push(bytes.freeze());
bytes = BytesMut::with_capacity(CHUNK);
}
// keep reading into the same chunk
Ok(_) => {}
// i/o error
Err(err) => break Err(Arc::new(err)),
}
};
// Initialize `buf` with the result.
// Error indicates that the `buf` was already initialized, but this can't happen
// since we synchronize through other means.
let r = self.buf.set(res);
debug_assert!(r.is_ok());
// Notify other tasks that `buf` is initialized.
// Error indicates that there is no one to notify anymore, but we don't care.
let _ = self.notify.send(());
}
// Wait until `dyn AsyncRead` is read into a buffer, if it hasn't been read yet
None if self.buf.get().is_none() => {
// Error indicates that the sender was dropped, but we hold `Arc<Sender>`, so
// this can't happen
let _ = self.wait.changed().await;
}
// Someone else has already initialized the buffer
None => {}
};
let buf = self.buf;
// unwrap: `OnceCell` is initialized in the match above before sending
// notification, so at this point it's already initialized.
match buf.get().unwrap() {
Ok(_) => {
// We can't use `.iter()` here, because the iterator must capture `buf`
// `i` is the index of the next chunk to yield; chunks are cloned out of the
// shared buffer one at a time.
let mut i = 0;
let iter = iter::from_fn(move || match buf.get().unwrap() {
Ok(buf) if i >= buf.len() => None,
Ok(buf) => {
let res = buf[i].clone();
i += 1;
Some(Ok::<_, Infallible>(res))
}
// We've just checked in the above match, it's `Ok(_)`
Err(_) => unreachable!(),
});
Body::wrap_stream(stream::iter(iter))
}
// Reading failed: the body is a stream whose only item is the shared error.
Err(err) => {
let err = Err::<Bytes, _>(Arc::clone(err));
Body::wrap_stream(stream::iter(iter::once(err)))
}
}
}
}
/// Wrapper over an `ArcBox` that implements `AsyncRead`.
///
/// Built on the fast path of `Read::into_part`, after the shared `Arc` has been
/// converted into an `ArcBox`, so reads can be forwarded straight to the reader
/// stored in the cell.
struct ExclusiveArcAsyncRead(ArcBox<TakeCell<dyn AsyncRead + Send + Unpin>>);
impl AsyncRead for ExclusiveArcAsyncRead {
    /// Forwards the read to the reader stored inside the wrapped cell.
    fn poll_read(
        self: Pin<&mut Self>,
        cx: &mut task::Context<'_>,
        buf: &mut ReadBuf<'_>,
    ) -> task::Poll<io::Result<()>> {
        // The wrapper is `Unpin`, so the pin can be unwrapped to reach the field.
        let this = self.get_mut();
        let reader: &mut (dyn AsyncRead + Unpin) = this.0.get();
        Pin::new(reader).poll_read(cx, buf)
    }
}
/// Decoder handed to `FramedRead` when streaming an exclusively owned reader
/// (see `Read::into_part`); its `Decoder` impl follows.
struct BytesDecoder;
impl Decoder for BytesDecoder {