Fix stripping prefix when nesting at / (#91)

* Fix stripping prefix when nesting at `/`

Fixes https://github.com/tokio-rs/axum/issues/88

* changelog
This commit is contained in:
David Pedersen 2021-08-02 22:40:33 +02:00 committed by GitHub
parent c6b7ad0f33
commit 6a078ddb71
No known key found for this signature in database
GPG key ID: 4AEE18F83AFDEB23
4 changed files with 135 additions and 101 deletions

View file

@ -7,7 +7,7 @@ and this project adheres to [Semantic Versioning](https://semver.org/spec/v2.0.0
# Unreleased
None.
- Fix stripping prefix when nesting services at `/` ([#91](https://github.com/tokio-rs/axum/pull/91))
## Breaking changes

View file

@ -958,15 +958,17 @@ where
fn strip_prefix(uri: &Uri, prefix: &str) -> Uri {
let path_and_query = if let Some(path_and_query) = uri.path_and_query() {
let mut new_path = if let Some(path) = path_and_query.path().strip_prefix(prefix) {
let new_path = if let Some(path) = path_and_query.path().strip_prefix(prefix) {
path
} else {
path_and_query.path()
};
if new_path.is_empty() {
new_path = "/";
}
let new_path = if new_path.starts_with('/') {
Cow::Borrowed(new_path)
} else {
Cow::Owned(format!("/{}", new_path))
};
if let Some(query) = path_and_query.query() {
Some(

View file

@ -1,10 +1,10 @@
use crate::{
extract::RequestParts, handler::on, prelude::*, response::IntoResponse, routing::MethodFilter,
service,
extract::RequestParts, handler::on, prelude::*, response::IntoResponse, routing::nest,
routing::MethodFilter, service,
};
use bytes::Bytes;
use futures_util::future::Ready;
use http::{header::AUTHORIZATION, Request, Response, StatusCode};
use http::{header::AUTHORIZATION, Request, Response, StatusCode, Uri};
use hyper::{Body, Server};
use serde::Deserialize;
use serde_json::json;
@ -17,6 +17,8 @@ use std::{
use tower::{make::Shared, service_fn, BoxError, Service, ServiceBuilder};
use tower_http::{compression::CompressionLayer, trace::TraceLayer};
mod nest;
#[tokio::test]
async fn hello_world() {
async fn root(_: Request<Body>) -> &'static str {
@ -521,72 +523,6 @@ async fn layer_on_whole_router() {
assert_eq!(res.status(), StatusCode::INTERNAL_SERVER_ERROR);
}
#[tokio::test]
async fn disjunction() {
let api_routes = route(
"/users",
get(|| async { "users#index" }).post(|| async { "users#create" }),
)
.route(
"/users/:id",
get(|params: extract::UrlParamsMap| async move {
format!(
"{}: users#show ({})",
params.get("version").unwrap(),
params.get("id").unwrap()
)
}),
)
.route(
"/games/:id",
get(|params: extract::UrlParamsMap| async move {
format!(
"{}: games#show ({})",
params.get("version").unwrap(),
params.get("id").unwrap()
)
}),
);
let app = route("/", get(|| async { "hi" })).nest("/:version/api", api_routes);
let addr = run_in_background(app).await;
let client = reqwest::Client::new();
let res = client
.get(format!("http://{}/", addr))
.send()
.await
.unwrap();
assert_eq!(res.status(), StatusCode::OK);
assert_eq!(res.text().await.unwrap(), "hi");
let res = client
.get(format!("http://{}/v0/api/users", addr))
.send()
.await
.unwrap();
assert_eq!(res.status(), StatusCode::OK);
assert_eq!(res.text().await.unwrap(), "users#index");
let res = client
.get(format!("http://{}/v0/api/users/123", addr))
.send()
.await
.unwrap();
assert_eq!(res.status(), StatusCode::OK);
assert_eq!(res.text().await.unwrap(), "v0: users#show (123)");
let res = client
.get(format!("http://{}/v0/api/games/123", addr))
.send()
.await
.unwrap();
assert_eq!(res.status(), StatusCode::OK);
assert_eq!(res.text().await.unwrap(), "v0: games#show (123)");
}
#[tokio::test]
async fn typed_header() {
use extract::TypedHeader;
@ -716,33 +652,6 @@ async fn wrong_method_handler() {
assert_eq!(res.status(), StatusCode::NOT_FOUND);
}
#[tokio::test]
async fn wrong_method_nest() {
let nested_app = route("/", get(|| async {}));
let app = crate::routing::nest("/", nested_app);
let addr = run_in_background(app).await;
let client = reqwest::Client::new();
let res = client.get(format!("http://{}", addr)).send().await.unwrap();
assert_eq!(res.status(), StatusCode::OK);
let res = client
.post(format!("http://{}", addr))
.send()
.await
.unwrap();
assert_eq!(res.status(), StatusCode::METHOD_NOT_ALLOWED);
let res = client
.patch(format!("http://{}/foo", addr))
.send()
.await
.unwrap();
assert_eq!(res.status(), StatusCode::NOT_FOUND);
}
#[tokio::test]
async fn wrong_method_service() {
#[derive(Clone)]

123
src/tests/nest.rs Normal file
View file

@ -0,0 +1,123 @@
use super::*;
#[tokio::test]
async fn nesting_apps() {
let api_routes = route(
"/users",
get(|| async { "users#index" }).post(|| async { "users#create" }),
)
.route(
"/users/:id",
get(|params: extract::UrlParamsMap| async move {
format!(
"{}: users#show ({})",
params.get("version").unwrap(),
params.get("id").unwrap()
)
}),
)
.route(
"/games/:id",
get(|params: extract::UrlParamsMap| async move {
format!(
"{}: games#show ({})",
params.get("version").unwrap(),
params.get("id").unwrap()
)
}),
);
let app = route("/", get(|| async { "hi" })).nest("/:version/api", api_routes);
let addr = run_in_background(app).await;
let client = reqwest::Client::new();
let res = client
.get(format!("http://{}/", addr))
.send()
.await
.unwrap();
assert_eq!(res.status(), StatusCode::OK);
assert_eq!(res.text().await.unwrap(), "hi");
let res = client
.get(format!("http://{}/v0/api/users", addr))
.send()
.await
.unwrap();
assert_eq!(res.status(), StatusCode::OK);
assert_eq!(res.text().await.unwrap(), "users#index");
let res = client
.get(format!("http://{}/v0/api/users/123", addr))
.send()
.await
.unwrap();
assert_eq!(res.status(), StatusCode::OK);
assert_eq!(res.text().await.unwrap(), "v0: users#show (123)");
let res = client
.get(format!("http://{}/v0/api/games/123", addr))
.send()
.await
.unwrap();
assert_eq!(res.status(), StatusCode::OK);
assert_eq!(res.text().await.unwrap(), "v0: games#show (123)");
}
#[tokio::test]
async fn wrong_method_nest() {
let nested_app = route("/", get(|| async {}));
let app = crate::routing::nest("/", nested_app);
let addr = run_in_background(app).await;
let client = reqwest::Client::new();
let res = client.get(format!("http://{}", addr)).send().await.unwrap();
assert_eq!(res.status(), StatusCode::OK);
let res = client
.post(format!("http://{}", addr))
.send()
.await
.unwrap();
assert_eq!(res.status(), StatusCode::METHOD_NOT_ALLOWED);
let res = client
.patch(format!("http://{}/foo", addr))
.send()
.await
.unwrap();
assert_eq!(res.status(), StatusCode::NOT_FOUND);
}
#[tokio::test]
async fn nesting_at_root() {
let app = nest("/", get(|uri: Uri| async move { uri.to_string() }));
let addr = run_in_background(app).await;
let client = reqwest::Client::new();
let res = client.get(format!("http://{}", addr)).send().await.unwrap();
assert_eq!(res.status(), StatusCode::OK);
assert_eq!(res.text().await.unwrap(), "/");
let res = client
.get(format!("http://{}/foo", addr))
.send()
.await
.unwrap();
assert_eq!(res.status(), StatusCode::OK);
assert_eq!(res.text().await.unwrap(), "/foo");
let res = client
.get(format!("http://{}/foo/bar", addr))
.send()
.await
.unwrap();
assert_eq!(res.status(), StatusCode::OK);
assert_eq!(res.text().await.unwrap(), "/foo/bar");
}