mirror of
https://github.com/tokio-rs/axum.git
synced 2025-01-23 08:47:03 +01:00
be71e7b286
* Fix vulnerability in example-stream-to-file example * Save files to separate directory
144 lines
4.2 KiB
Rust
144 lines
4.2 KiB
Rust
//! Run with
|
|
//!
|
|
//! ```not_rust
|
|
//! cd examples && cargo run -p example-stream-to-file
|
|
//! ```
|
|
|
|
use axum::{
|
|
body::Bytes,
|
|
extract::{BodyStream, Multipart, Path},
|
|
http::StatusCode,
|
|
response::{Html, Redirect},
|
|
routing::{get, post},
|
|
BoxError, Router,
|
|
};
|
|
use futures::{Stream, TryStreamExt};
|
|
use std::{io, net::SocketAddr};
|
|
use tokio::{fs::File, io::BufWriter};
|
|
use tokio_util::io::StreamReader;
|
|
use tracing_subscriber::{layer::SubscriberExt, util::SubscriberInitExt};
|
|
|
|
const UPLOADS_DIRECTORY: &str = "uploads";
|
|
|
|
#[tokio::main]
|
|
async fn main() {
|
|
tracing_subscriber::registry()
|
|
.with(tracing_subscriber::EnvFilter::new(
|
|
std::env::var("RUST_LOG").unwrap_or_else(|_| "example_stream_to_file=debug".into()),
|
|
))
|
|
.with(tracing_subscriber::fmt::layer())
|
|
.init();
|
|
|
|
// save files to a separte directory to not override files in the current directory
|
|
tokio::fs::create_dir(UPLOADS_DIRECTORY)
|
|
.await
|
|
.expect("failed to create `uploads` directory");
|
|
|
|
let app = Router::new()
|
|
.route("/", get(show_form).post(accept_form))
|
|
.route("/file/:file_name", post(save_request_body));
|
|
|
|
let addr = SocketAddr::from(([127, 0, 0, 1], 3000));
|
|
tracing::debug!("listening on {}", addr);
|
|
axum::Server::bind(&addr)
|
|
.serve(app.into_make_service())
|
|
.await
|
|
.unwrap();
|
|
}
|
|
|
|
// Handler that streams the request body to a file.
|
|
//
|
|
// POST'ing to `/file/foo.txt` will create a file called `foo.txt`.
|
|
async fn save_request_body(
|
|
Path(file_name): Path<String>,
|
|
body: BodyStream,
|
|
) -> Result<(), (StatusCode, String)> {
|
|
stream_to_file(&file_name, body).await
|
|
}
|
|
|
|
// Handler that returns HTML for a multipart form.
|
|
async fn show_form() -> Html<&'static str> {
|
|
Html(
|
|
r#"
|
|
<!doctype html>
|
|
<html>
|
|
<head>
|
|
<title>Upload something!</title>
|
|
</head>
|
|
<body>
|
|
<form action="/" method="post" enctype="multipart/form-data">
|
|
<div>
|
|
<label>
|
|
Upload file:
|
|
<input type="file" name="file" multiple>
|
|
</label>
|
|
</div>
|
|
|
|
<div>
|
|
<input type="submit" value="Upload files">
|
|
</div>
|
|
</form>
|
|
</body>
|
|
</html>
|
|
"#,
|
|
)
|
|
}
|
|
|
|
// Handler that accepts a multipart form upload and streams each field to a file.
|
|
async fn accept_form(mut multipart: Multipart) -> Result<Redirect, (StatusCode, String)> {
|
|
while let Some(field) = multipart.next_field().await.unwrap() {
|
|
let file_name = if let Some(file_name) = field.file_name() {
|
|
file_name.to_owned()
|
|
} else {
|
|
continue;
|
|
};
|
|
|
|
stream_to_file(&file_name, field).await?;
|
|
}
|
|
|
|
Ok(Redirect::to("/"))
|
|
}
|
|
|
|
// Save a `Stream` to a file
|
|
async fn stream_to_file<S, E>(path: &str, stream: S) -> Result<(), (StatusCode, String)>
|
|
where
|
|
S: Stream<Item = Result<Bytes, E>>,
|
|
E: Into<BoxError>,
|
|
{
|
|
if !path_is_valid(path) {
|
|
return Err((StatusCode::BAD_REQUEST, "Invalid path".to_owned()));
|
|
}
|
|
|
|
async {
|
|
// Convert the stream into an `AsyncRead`.
|
|
let body_with_io_error = stream.map_err(|err| io::Error::new(io::ErrorKind::Other, err));
|
|
let body_reader = StreamReader::new(body_with_io_error);
|
|
futures::pin_mut!(body_reader);
|
|
|
|
// Create the file. `File` implements `AsyncWrite`.
|
|
let path = std::path::Path::new(UPLOADS_DIRECTORY).join(path);
|
|
let mut file = BufWriter::new(File::create(path).await?);
|
|
|
|
// Copy the body into the file.
|
|
tokio::io::copy(&mut body_reader, &mut file).await?;
|
|
|
|
Ok::<_, io::Error>(())
|
|
}
|
|
.await
|
|
.map_err(|err| (StatusCode::INTERNAL_SERVER_ERROR, err.to_string()))
|
|
}
|
|
|
|
// Guards against directory traversal. A valid `path` consists of exactly one
// `Component::Normal`, i.e. a bare file name such as `foo.txt`. The following are rejected:
// empty paths, `.` and `..`, absolute paths (`/etc/passwd`), and anything made of several
// components (`a/b`, `../x`, `./x`).
// NOTE: `Path::components` drops a trailing separator, so `foo/` still counts as one component.
fn path_is_valid(path: &str) -> bool {
    let mut parts = std::path::Path::new(path).components();
    matches!(
        (parts.next(), parts.next()),
        (Some(std::path::Component::Normal(_)), None)
    )
}
|