diff --git a/Cargo.toml b/Cargo.toml index 1d22bbe2..d0b0923c 100644 --- a/Cargo.toml +++ b/Cargo.toml @@ -39,7 +39,9 @@ derive_more = "0.99.9" mime = "0.3.16" thiserror = "1.0.20" once_cell = "1.5.0" +takecell = "0.1" take_mut = "0.2" +rc-box = "1.1.1" never = "0.1.0" chrono = { version = "0.4.19", default-features = false } either = "1.6.1" diff --git a/src/types/input_file.rs b/src/types/input_file.rs index 26ee7b10..aadfe568 100644 --- a/src/types/input_file.rs +++ b/src/types/input_file.rs @@ -4,11 +4,27 @@ use futures::{ stream, }; use once_cell::sync::OnceCell; +use rc_box::ArcBox; use reqwest::{multipart::Part, Body}; use serde::Serialize; +use takecell::TakeCell; +use tokio::{ + io::{AsyncRead, AsyncReadExt, ReadBuf}, + sync::watch, +}; use tokio_util::codec::{Decoder, FramedRead}; -use std::{borrow::Cow, fmt, future::Future, io, mem, path::PathBuf, sync::Arc}; +use std::{ + borrow::Cow, + convert::{Infallible, TryFrom}, + fmt, + future::Future, + io, iter, mem, + path::PathBuf, + pin::Pin, + sync::Arc, + task, +}; use crate::types::InputSticker; @@ -24,6 +40,7 @@ pub struct InputFile { #[derive(Clone)] enum InnerFile { + Read(Read), File(PathBuf), Bytes(bytes::Bytes), Url(url::Url), @@ -87,6 +104,14 @@ impl InputFile { self } + /// Creates an `InputFile` from a in-memory bytes. + /// + /// Note: in some cases (e.g. sending the same `InputFile` multiple times) + /// this may read the whole `impl AsyncRead` into memory. + pub fn read(it: impl AsyncRead + Send + Unpin + 'static) -> Self { + Self::new(Read(Read::new(Arc::new(TakeCell::new(it))))) + } + /// Shorthand for `Self { file_name: None, inner, id: default() }` /// (private because `InnerFile` iы private implementation detail) fn new(inner: InnerFile) -> Self { @@ -157,6 +182,7 @@ impl InputFile { impl fmt::Debug for InnerFile { fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result { match self { + Read(_) => f.debug_struct("Read").finish_non_exhaustive(), File(path) => f.debug_struct("File").field("path", path).finish(), Bytes(bytes) if f.alternate() => f.debug_tuple("Memory").field(bytes).finish(), Bytes(_) => f.debug_struct("Memory").finish_non_exhaustive(), @@ -207,14 +233,160 @@ impl InputFile { } Bytes(data) => { let stream = Part::stream(data).file_name(filename); - Some(Either::Right(ready(stream))) + Some(Either::Right(Either::Left(ready(stream)))) } + Read(read) => Some(Either::Right(Either::Right(read.into_part(filename)))), }; file_part } } +/// Adaptor for `AsyncRead` that allows clonning and converting to +/// `multipart/form-data` +#[derive(Clone)] +struct Read { + inner: Arc>, + buf: Arc, Arc>>>, + notify: Arc>, + wait: watch::Receiver<()>, +} + +impl Read { + fn new(it: Arc>) -> Self { + let (tx, rx) = watch::channel(()); + + Self { + inner: it, + buf: Arc::default(), + notify: Arc::new(tx), + wait: rx, + } + } + + pub(crate) async fn into_part(mut self, filename: Cow<'static, str>) -> Part { + if !self.inner.is_taken() { + let res = ArcBox::>::try_from(self.inner); + match res { + // Fast/easy path: this is the only file copy, so we can just forward the underlying + // `dyn AsynсRead` via some adaptors to reqwest. + Ok(arc_box) => { + let fr = FramedRead::new(ExclusiveArcAsyncRead(arc_box), BytesDecoder); + + let body = Body::wrap_stream(fr); + return Part::stream(body).file_name(filename); + } + // move the arc back into `self` + Err(i) => self.inner = i, + } + } + + // Slow path: either wait until someone will read the whole `dyn AsynсRead` into + // a buffer, or be the one who reads + let body = self.into_shared_body().await; + + Part::stream(body).file_name(filename) + } + + async fn into_shared_body(mut self) -> Body { + match self.inner.take() { + // Read `dyn AsyncRead` into a buffer + Some(mut read_ref) => { + // Chunk size, arbitrary chosen to be 1KiB + const CHUNK: usize = 1024; + + let mut chunks = Vec::new(); + let mut bytes = BytesMut::with_capacity(CHUNK); + + let res = loop { + match (&mut read_ref).read_buf(&mut bytes).await { + // eof + Ok(0) if bytes.len() < bytes.capacity() => { + chunks.push(bytes.freeze()); + + break Ok(chunks); + } + + // No space left in bytes, allocate a new chunk + Ok(0) => { + chunks.push(bytes.freeze()); + bytes = BytesMut::with_capacity(CHUNK); + } + + // keep reading into the same chunk + Ok(_) => {} + + // i/o error + Err(err) => break Err(Arc::new(err)), + } + }; + + // Initialize `buf` with the result. + // Error indicates that the `buf` was already initialized, but this can't happen + // since we synchronize through other means. + let r = self.buf.set(res); + debug_assert!(r.is_ok()); + + // Notify other tasks that `buf` is initialized. + // Error indicates that there is no one to notify anymore, but we don't care. + let _ = self.notify.send(()); + } + + // Wait until `dyn AsynсRead` is read into a buffer, if it hasn't been read yet + None if self.buf.get().is_none() => { + // Error indicates that the sender was dropped, by we hold `Arc`, so + // this can't happen + let _ = self.wait.changed().await; + } + + // Someone else has already initialized the buffer + None => {} + }; + + let buf = self.buf; + // unwrap: `OnceCell` is initialized in the match above before sending + // notification, so at this point it's already initialized. + match buf.get().unwrap() { + Ok(_) => { + // We can't use `.iter()` here, because the iterator must capture `buf` + let mut i = 0; + let iter = iter::from_fn(move || match buf.get().unwrap() { + Ok(buf) if i >= buf.len() => None, + Ok(buf) => { + let res = buf[i].clone(); + i += 1; + Some(Ok::<_, Infallible>(res)) + } + // We've just checked in the above match, it's `Ok(_)` + Err(_) => unreachable!(), + }); + + Body::wrap_stream(stream::iter(iter)) + } + + Err(err) => { + let err = Err::(Arc::clone(err)); + Body::wrap_stream(stream::iter(iter::once(err))) + } + } + } +} + +/// Wrapper over an `ArcBox` that implements `AsyncRead`. +struct ExclusiveArcAsyncRead(ArcBox>); + +impl AsyncRead for ExclusiveArcAsyncRead { + fn poll_read( + self: Pin<&mut Self>, + cx: &mut task::Context<'_>, + buf: &mut ReadBuf<'_>, + ) -> task::Poll> { + let Self(inner) = Pin::get_mut(self); + let read: &mut (dyn AsyncRead + Unpin) = inner.get(); + Pin::new(read).poll_read(cx, buf) + } +} + struct BytesDecoder; impl Decoder for BytesDecoder {