Skip to content

Commit

Permalink
feat(server): add HTTP/1 header read timeout option (#2675)
Browse files Browse the repository at this point in the history
Adds `Server::http1_header_read_timeout(Duration)`. Setting a duration will determine how long a client has to finish sending all the request headers before trigger a timeout test. This can help reduce resource usage when bad actors open connections without sending full requests.

Closes #2457
  • Loading branch information
paolobarbolini authored Nov 18, 2021
1 parent d0b1d9e commit 842c655
Show file tree
Hide file tree
Showing 9 changed files with 308 additions and 3 deletions.
1 change: 1 addition & 0 deletions Cargo.toml
Original file line number Diff line number Diff line change
Expand Up @@ -100,6 +100,7 @@ stream = []
runtime = [
"tcp",
"tokio/rt",
"tokio/time",
]
tcp = [
"socket2",
Expand Down
10 changes: 10 additions & 0 deletions src/error.rs
Original file line number Diff line number Diff line change
Expand Up @@ -44,6 +44,9 @@ pub(super) enum Kind {
#[cfg(any(feature = "http1", feature = "http2"))]
#[cfg(feature = "server")]
Accept,
/// User took too long to send headers
#[cfg(all(feature = "http1", feature = "server", feature = "runtime"))]
HeaderTimeout,
/// Error while reading a body from connection.
#[cfg(any(feature = "http1", feature = "http2", feature = "stream"))]
Body,
Expand Down Expand Up @@ -310,6 +313,11 @@ impl Error {
Error::new_user(User::UnexpectedHeader)
}

#[cfg(all(feature = "http1", feature = "server", feature = "runtime"))]
pub(super) fn new_header_timeout() -> Error {
Error::new(Kind::HeaderTimeout)
}

#[cfg(any(feature = "http1", feature = "http2"))]
#[cfg(feature = "client")]
pub(super) fn new_user_unsupported_version() -> Error {
Expand Down Expand Up @@ -419,6 +427,8 @@ impl Error {
#[cfg(any(feature = "http1", feature = "http2"))]
#[cfg(feature = "server")]
Kind::Accept => "error accepting connection",
#[cfg(all(feature = "http1", feature = "server", feature = "runtime"))]
Kind::HeaderTimeout => "read header from client timeout",
#[cfg(any(feature = "http1", feature = "http2", feature = "stream"))]
Kind::Body => "error reading a body from connection",
#[cfg(any(feature = "http1", feature = "http2"))]
Expand Down
26 changes: 26 additions & 0 deletions src/proto/h1/conn.rs
Original file line number Diff line number Diff line change
@@ -1,12 +1,15 @@
use std::fmt;
use std::io;
use std::marker::PhantomData;
use std::time::Duration;

use bytes::{Buf, Bytes};
use http::header::{HeaderValue, CONNECTION};
use http::{HeaderMap, Method, Version};
use httparse::ParserConfig;
use tokio::io::{AsyncRead, AsyncWrite};
#[cfg(all(feature = "server", feature = "runtime"))]
use tokio::time::Sleep;
use tracing::{debug, error, trace};

use super::io::Buffered;
Expand Down Expand Up @@ -47,6 +50,12 @@ where
keep_alive: KA::Busy,
method: None,
h1_parser_config: ParserConfig::default(),
#[cfg(all(feature = "server", feature = "runtime"))]
h1_header_read_timeout: None,
#[cfg(all(feature = "server", feature = "runtime"))]
h1_header_read_timeout_fut: None,
#[cfg(all(feature = "server", feature = "runtime"))]
h1_header_read_timeout_running: false,
preserve_header_case: false,
title_case_headers: false,
h09_responses: false,
Expand Down Expand Up @@ -106,6 +115,11 @@ where
self.state.h09_responses = true;
}

#[cfg(all(feature = "server", feature = "runtime"))]
pub(crate) fn set_http1_header_read_timeout(&mut self, val: Duration) {
self.state.h1_header_read_timeout = Some(val);
}

#[cfg(feature = "server")]
pub(crate) fn set_allow_half_close(&mut self) {
self.state.allow_half_close = true;
Expand Down Expand Up @@ -178,6 +192,12 @@ where
cached_headers: &mut self.state.cached_headers,
req_method: &mut self.state.method,
h1_parser_config: self.state.h1_parser_config.clone(),
#[cfg(all(feature = "server", feature = "runtime"))]
h1_header_read_timeout: self.state.h1_header_read_timeout,
#[cfg(all(feature = "server", feature = "runtime"))]
h1_header_read_timeout_fut: &mut self.state.h1_header_read_timeout_fut,
#[cfg(all(feature = "server", feature = "runtime"))]
h1_header_read_timeout_running: &mut self.state.h1_header_read_timeout_running,
preserve_header_case: self.state.preserve_header_case,
h09_responses: self.state.h09_responses,
#[cfg(feature = "ffi")]
Expand Down Expand Up @@ -798,6 +818,12 @@ struct State {
/// a body or not.
method: Option<Method>,
h1_parser_config: ParserConfig,
#[cfg(all(feature = "server", feature = "runtime"))]
h1_header_read_timeout: Option<Duration>,
#[cfg(all(feature = "server", feature = "runtime"))]
h1_header_read_timeout_fut: Option<Pin<Box<Sleep>>>,
#[cfg(all(feature = "server", feature = "runtime"))]
h1_header_read_timeout_running: bool,
preserve_header_case: bool,
title_case_headers: bool,
h09_responses: bool,
Expand Down
38 changes: 37 additions & 1 deletion src/proto/h1/io.rs
Original file line number Diff line number Diff line change
Expand Up @@ -3,10 +3,15 @@ use std::fmt;
use std::io::{self, IoSlice};
use std::marker::Unpin;
use std::mem::MaybeUninit;
use std::future::Future;
#[cfg(all(feature = "server", feature = "runtime"))]
use std::time::Duration;

#[cfg(all(feature = "server", feature = "runtime"))]
use tokio::time::Instant;
use bytes::{Buf, BufMut, Bytes, BytesMut};
use tokio::io::{AsyncRead, AsyncWrite, ReadBuf};
use tracing::{debug, trace};
use tracing::{debug, warn, trace};

use super::{Http1Transaction, ParseContext, ParsedMessage};
use crate::common::buf::BufList;
Expand Down Expand Up @@ -181,6 +186,12 @@ where
cached_headers: parse_ctx.cached_headers,
req_method: parse_ctx.req_method,
h1_parser_config: parse_ctx.h1_parser_config.clone(),
#[cfg(all(feature = "server", feature = "runtime"))]
h1_header_read_timeout: parse_ctx.h1_header_read_timeout,
#[cfg(all(feature = "server", feature = "runtime"))]
h1_header_read_timeout_fut: parse_ctx.h1_header_read_timeout_fut,
#[cfg(all(feature = "server", feature = "runtime"))]
h1_header_read_timeout_running: parse_ctx.h1_header_read_timeout_running,
preserve_header_case: parse_ctx.preserve_header_case,
h09_responses: parse_ctx.h09_responses,
#[cfg(feature = "ffi")]
Expand All @@ -191,6 +202,16 @@ where
)? {
Some(msg) => {
debug!("parsed {} headers", msg.head.headers.len());

#[cfg(all(feature = "server", feature = "runtime"))]
{
*parse_ctx.h1_header_read_timeout_running = false;

if let Some(h1_header_read_timeout_fut) = parse_ctx.h1_header_read_timeout_fut {
// Reset the timer in order to avoid woken up when the timeout finishes
h1_header_read_timeout_fut.as_mut().reset(Instant::now() + Duration::from_secs(30 * 24 * 60 * 60));
}
}
return Poll::Ready(Ok(msg));
}
None => {
Expand All @@ -199,6 +220,18 @@ where
debug!("max_buf_size ({}) reached, closing", max);
return Poll::Ready(Err(crate::Error::new_too_large()));
}

#[cfg(all(feature = "server", feature = "runtime"))]
if *parse_ctx.h1_header_read_timeout_running {
if let Some(h1_header_read_timeout_fut) = parse_ctx.h1_header_read_timeout_fut {
if Pin::new( h1_header_read_timeout_fut).poll(cx).is_ready() {
*parse_ctx.h1_header_read_timeout_running = false;

warn!("read header from client timeout");
return Poll::Ready(Err(crate::Error::new_header_timeout()))
}
}
}
}
}
if ready!(self.poll_read_from_io(cx)).map_err(crate::Error::new_io)? == 0 {
Expand Down Expand Up @@ -693,6 +726,9 @@ mod tests {
cached_headers: &mut None,
req_method: &mut None,
h1_parser_config: Default::default(),
h1_header_read_timeout: None,
h1_header_read_timeout_fut: &mut None,
h1_header_read_timeout_running: &mut false,
preserve_header_case: false,
h09_responses: false,
#[cfg(feature = "ffi")]
Expand Down
11 changes: 11 additions & 0 deletions src/proto/h1/mod.rs
Original file line number Diff line number Diff line change
@@ -1,6 +1,11 @@
use std::pin::Pin;
use std::time::Duration;

use bytes::BytesMut;
use http::{HeaderMap, Method};
use httparse::ParserConfig;
#[cfg(all(feature = "server", feature = "runtime"))]
use tokio::time::Sleep;

use crate::body::DecodedLength;
use crate::proto::{BodyLength, MessageHead};
Expand Down Expand Up @@ -72,6 +77,12 @@ pub(crate) struct ParseContext<'a> {
cached_headers: &'a mut Option<HeaderMap>,
req_method: &'a mut Option<Method>,
h1_parser_config: ParserConfig,
#[cfg(all(feature = "server", feature = "runtime"))]
h1_header_read_timeout: Option<Duration>,
#[cfg(all(feature = "server", feature = "runtime"))]
h1_header_read_timeout_fut: &'a mut Option<Pin<Box<Sleep>>>,
#[cfg(all(feature = "server", feature = "runtime"))]
h1_header_read_timeout_running: &'a mut bool,
preserve_header_case: bool,
h09_responses: bool,
#[cfg(feature = "ffi")]
Expand Down
Loading

0 comments on commit 842c655

Please sign in to comment.