Skip to content

Commit

Permalink
docs, tests, and nesting for matchit
Browse files Browse the repository at this point in the history
  • Loading branch information
mladedav committed Mar 22, 2024
1 parent 1861e1c commit 178eeba
Show file tree
Hide file tree
Showing 6 changed files with 186 additions and 20 deletions.
43 changes: 43 additions & 0 deletions axum/src/docs/routing/without_v07_checks.md
Original file line number Diff line number Diff line change
@@ -0,0 +1,43 @@
Turn off checks for compatibility with route matching syntax from 0.7.

This allows usage of paths starting with a colon `:` or an asterisk `*` which are otherwise prohibited.

# Example

```rust
use axum::{
routing::get,
Router,
};

let app = Router::<()>::new()
.without_v07_checks()
.route("/:colon", get(|| async {}))
.route("/*asterisk", get(|| async {}));

// Our app now accepts
// - GET /:colon
// - GET /*asterisk
# let _: Router = app;
```

Adding such routes without calling this method first will panic.

```rust,should_panic
use axum::{
routing::get,
Router,
};
// This panics...
let app = Router::<()>::new()
.route("/:colon", get(|| async {}));
```

# Merging

When two routers are merged, v0.7 checks are disabled if both of the two routers had them also disabled.

# Nesting

Each router needs to have the checks explicitly disabled. Nesting a router with the checks either enabled or disabled has no effect on the outer router.
40 changes: 40 additions & 0 deletions axum/src/extract/matched_path.rs
Original file line number Diff line number Diff line change
Expand Up @@ -351,4 +351,44 @@ mod tests {
let res = client.get("/foo/bar").await;
assert_eq!(res.status(), StatusCode::OK);
}

#[crate::test]
async fn matching_colon() {
let app = Router::new().without_v07_checks().route(
"/:foo",
get(|path: MatchedPath| async move { path.as_str().to_owned() }),
);

let client = TestClient::new(app);

let res = client.get("/:foo").await;
assert_eq!(res.status(), StatusCode::OK);
assert_eq!(res.text().await, "/:foo");

let res = client.get("/:bar").await;
assert_eq!(res.status(), StatusCode::NOT_FOUND);

let res = client.get("/foo").await;
assert_eq!(res.status(), StatusCode::NOT_FOUND);
}

#[crate::test]
async fn matching_asterisk() {
let app = Router::new().without_v07_checks().route(
"/*foo",
get(|path: MatchedPath| async move { path.as_str().to_owned() }),
);

let client = TestClient::new(app);

let res = client.get("/*foo").await;
assert_eq!(res.status(), StatusCode::OK);
assert_eq!(res.text().await, "/*foo");

let res = client.get("/*bar").await;
assert_eq!(res.status(), StatusCode::NOT_FOUND);

let res = client.get("/foo").await;
assert_eq!(res.status(), StatusCode::NOT_FOUND);
}
}
7 changes: 7 additions & 0 deletions axum/src/routing/mod.rs
Original file line number Diff line number Diff line change
Expand Up @@ -154,6 +154,13 @@ where
}
}

#[doc = include_str!("../docs/routing/without_v07_checks.md")]
pub fn without_v07_checks(self) -> Self {
self.tap_inner_mut(|this| {
this.path_router.without_v07_checks();
})
}

#[doc = include_str!("../docs/routing/route.md")]
#[track_caller]
pub fn route(self, path: &str, method_router: MethodRouter<S>) -> Self {
Expand Down
84 changes: 64 additions & 20 deletions axum/src/routing/path_router.rs
Original file line number Diff line number Diff line change
Expand Up @@ -14,6 +14,7 @@ pub(super) struct PathRouter<S, const IS_FALLBACK: bool> {
routes: HashMap<RouteId, Endpoint<S>>,
node: Arc<Node>,
prev_route_id: RouteId,
v7_checks: bool,
}

impl<S> PathRouter<S, true>
Expand All @@ -32,26 +33,56 @@ where
}
}

fn validate_path(v7_checks: bool, path: &str) -> Result<(), &'static str> {
if path.is_empty() {
return Err("Paths must start with a `/`. Use \"/\" for root routes");
} else if !path.starts_with('/') {
return Err("Paths must start with a `/`");
}

if v7_checks {
validate_v07_paths(path)?;
}

Ok(())
}

fn validate_v07_paths(path: &str) -> Result<(), &'static str> {
path.split('/')
.find_map(|segment| {
if segment.starts_with(':') {
Some(Err(
"Path segments must not start with `:`. For capture groups, use \
`{capture}`. If you meant to literally match a segment starting with \
a colon, call `without_v07_checks` on the router.",
))
} else if segment.starts_with('*') {
Some(Err(
"Path segments must not start with `*`. For wildcard capture, use \
`{*wildcard}`. If you meant to literally match a segment starting with \
an asterisk, call `without_v07_checks` on the router.",
))
} else {
None
}
})
.unwrap_or(Ok(()))
}

impl<S, const IS_FALLBACK: bool> PathRouter<S, IS_FALLBACK>
where
S: Clone + Send + Sync + 'static,
{
pub(super) fn without_v07_checks(&mut self) {
self.v7_checks = false;
}

pub(super) fn route(
&mut self,
path: &str,
method_router: MethodRouter<S>,
) -> Result<(), Cow<'static, str>> {
fn validate_path(path: &str) -> Result<(), &'static str> {
if path.is_empty() {
return Err("Paths must start with a `/`. Use \"/\" for root routes");
} else if !path.starts_with('/') {
return Err("Paths must start with a `/`");
}

Ok(())
}

validate_path(path)?;
validate_path(self.v7_checks, path)?;

let endpoint = if let Some((route_id, Endpoint::MethodRouter(prev_method_router))) = self
.node
Expand Down Expand Up @@ -97,11 +128,7 @@ where
path: &str,
endpoint: Endpoint<S>,
) -> Result<(), Cow<'static, str>> {
if path.is_empty() {
return Err("Paths must start with a `/`. Use \"/\" for root routes".into());
} else if !path.starts_with('/') {
return Err("Paths must start with a `/`".into());
}
validate_path(self.v7_checks, path)?;

let id = self.next_route_id();
self.set_node(path, id)?;
Expand All @@ -128,8 +155,12 @@ where
routes,
node,
prev_route_id: _,
v7_checks,
} = other;

// If either of the two did not allow paths starting with `:` or `*`, do not allow them for the merged router either.
self.v7_checks |= v7_checks;

for (id, route) in routes {
let path = node
.route_id_to_path
Expand Down Expand Up @@ -165,12 +196,14 @@ where
path_to_nest_at: &str,
router: PathRouter<S, IS_FALLBACK>,
) -> Result<(), Cow<'static, str>> {
let prefix = validate_nest_path(path_to_nest_at);
let prefix = validate_nest_path(self.v7_checks, path_to_nest_at);

let PathRouter {
routes,
node,
prev_route_id: _,
// Ignore the configuration of the nested router
v7_checks: _,
} = router;

for (id, endpoint) in routes {
Expand Down Expand Up @@ -208,7 +241,7 @@ where
T::Response: IntoResponse,
T::Future: Send + 'static,
{
let path = validate_nest_path(path_to_nest_at);
let path = validate_nest_path(self.v7_checks, path_to_nest_at);
let prefix = path;

let path = if path.ends_with('/') {
Expand Down Expand Up @@ -258,6 +291,7 @@ where
routes,
node: self.node,
prev_route_id: self.prev_route_id,
v7_checks: self.v7_checks,
}
}

Expand Down Expand Up @@ -290,6 +324,7 @@ where
routes,
node: self.node,
prev_route_id: self.prev_route_id,
v7_checks: self.v7_checks,
}
}

Expand All @@ -312,6 +347,7 @@ where
routes,
node: self.node,
prev_route_id: self.prev_route_id,
v7_checks: self.v7_checks,
}
}

Expand Down Expand Up @@ -394,6 +430,7 @@ impl<S, const IS_FALLBACK: bool> Default for PathRouter<S, IS_FALLBACK> {
routes: Default::default(),
node: Default::default(),
prev_route_id: RouteId(0),
v7_checks: true,
}
}
}
Expand All @@ -413,6 +450,7 @@ impl<S, const IS_FALLBACK: bool> Clone for PathRouter<S, IS_FALLBACK> {
routes: self.routes.clone(),
node: self.node.clone(),
prev_route_id: self.prev_route_id,
v7_checks: self.v7_checks,
}
}
}
Expand Down Expand Up @@ -459,16 +497,22 @@ impl fmt::Debug for Node {
}

#[track_caller]
fn validate_nest_path(path: &str) -> &str {
fn validate_nest_path(v7_checks: bool, path: &str) -> &str {
if path.is_empty() {
// nesting at `""` and `"/"` should mean the same thing
return "/";
}

if path.contains('*') {
if path.split('/').any(|segment| {
segment.starts_with("{*") && segment.ends_with('}') && !segment.ends_with("}}")
}) {
panic!("Invalid route: nested routes cannot contain wildcards (*)");
}

if v7_checks {
validate_v07_paths(path).unwrap();
}

path
}

Expand Down
16 changes: 16 additions & 0 deletions axum/src/routing/tests/mod.rs
Original file line number Diff line number Diff line change
Expand Up @@ -1102,3 +1102,19 @@ async fn locks_mutex_very_little() {
assert_eq!(num, 1);
}
}

#[crate::test]
#[should_panic(
expected = "Path segments must not start with `:`. For capture groups, use `{capture}`. If you meant to literally match a segment starting with a colon, call `without_v07_checks` on the router."
)]
async fn colon_in_route() {
_ = Router::<()>::new().route("/:foo", get(|| async move {}));
}

#[crate::test]
#[should_panic(
expected = "Path segments must not start with `*`. For wildcard capture, use `{*wildcard}`. If you meant to literally match a segment starting with an asterisk, call `without_v07_checks` on the router."
)]
async fn asterisk_in_route() {
_ = Router::<()>::new().route("/*foo", get(|| async move {}));
}
16 changes: 16 additions & 0 deletions axum/src/routing/tests/nest.rs
Original file line number Diff line number Diff line change
Expand Up @@ -418,3 +418,19 @@ nested_route_test!(nest_9, nest = "/a", route = "/a/", expected = "/a/a/");
nested_route_test!(nest_11, nest = "/a/", route = "/", expected = "/a/");
nested_route_test!(nest_12, nest = "/a/", route = "/a", expected = "/a/a");
nested_route_test!(nest_13, nest = "/a/", route = "/a/", expected = "/a/a/");

#[crate::test]
#[should_panic(
expected = "Path segments must not start with `:`. For capture groups, use `{capture}`. If you meant to literally match a segment starting with a colon, call `without_v07_checks` on the router."
)]
async fn colon_in_route() {
_ = Router::<()>::new().nest("/:foo", Router::new());
}

#[crate::test]
#[should_panic(
expected = "Path segments must not start with `*`. For wildcard capture, use `{*wildcard}`. If you meant to literally match a segment starting with an asterisk, call `without_v07_checks` on the router."
)]
async fn asterisk_in_route() {
_ = Router::<()>::new().nest("/*foo", Router::new());
}

0 comments on commit 178eeba

Please sign in to comment.