From 38037222fbb9af9d12ca287e178e6ae91df2264c Mon Sep 17 00:00:00 2001 From: David Pedersen Date: Fri, 14 May 2021 19:33:29 +0200 Subject: [PATCH] feat(transport): provide generic access to connect info --- examples/src/uds/server.rs | 23 ++++- tests/integration_tests/tests/connect_info.rs | 52 ++++++++++ tonic/src/request.rs | 55 +++++++++-- tonic/src/transport/server/conn.rs | 92 +++++++++++++----- tonic/src/transport/server/incoming.rs | 16 ++- tonic/src/transport/server/mod.rs | 35 ++++--- tonic/src/transport/service/io.rs | 97 ++++++++++++++----- 7 files changed, 292 insertions(+), 78 deletions(-) create mode 100644 tests/integration_tests/tests/connect_info.rs diff --git a/examples/src/uds/server.rs b/examples/src/uds/server.rs index 07be063c9..77b65b49d 100644 --- a/examples/src/uds/server.rs +++ b/examples/src/uds/server.rs @@ -26,6 +26,11 @@ impl Greeter for MyGreeter { ) -> Result, Status> { println!("Got a request: {:?}", request); + let conn_info = request + .connect_info_from_incoming::() + .unwrap(); + dbg!(&conn_info); + let reply = hello_world::HelloReply { message: format!("Hello {}!", request.into_inner().name), }; @@ -64,6 +69,7 @@ async fn main() -> Result<(), Box> { mod unix { use std::{ pin::Pin, + sync::Arc, task::{Context, Poll}, }; @@ -73,7 +79,22 @@ mod unix { #[derive(Debug)] pub struct UnixStream(pub tokio::net::UnixStream); - impl Connected for UnixStream {} + impl Connected for UnixStream { + type ConnectInfo = UdsConnectInfo; + + fn connect_info(&self) -> Self::ConnectInfo { + UdsConnectInfo { + peer_addr: self.0.peer_addr().ok().map(Arc::new), + peer_cred: self.0.peer_cred().ok(), + } + } + } + + #[derive(Clone, Debug)] + pub struct UdsConnectInfo { + pub peer_addr: Option>, + pub peer_cred: Option, + } impl AsyncRead for UnixStream { fn poll_read( diff --git a/tests/integration_tests/tests/connect_info.rs b/tests/integration_tests/tests/connect_info.rs new file mode 100644 index 000000000..92a2798f7 --- /dev/null +++ b/tests/integration_tests/tests/connect_info.rs @@ -0,0 +1,52 @@ +use futures_util::FutureExt; +use integration_tests::pb::{test_client, test_server, Input, Output}; +use std::time::Duration; +use tokio::sync::oneshot; +use tonic::{ + transport::{Endpoint, Server}, + Request, Response, Status, +}; + +#[tokio::test] +async fn getting_connect_info() { + struct Svc; + + #[tonic::async_trait] + impl test_server::Test for Svc { + async fn unary_call(&self, req: Request) -> Result, Status> { + match req.metadata().get("user-agent") { + Some(_) => Ok(Response::new(Output {})), + None => Err(Status::internal("user-agent header is missing")), + } + } + } + + let svc = test_server::TestServer::new(Svc); + + let (tx, rx) = oneshot::channel::<()>(); + + let jh = tokio::spawn(async move { + Server::builder() + .add_service(svc) + .serve_with_shutdown("127.0.0.1:1400".parse().unwrap(), rx.map(drop)) + .await + .unwrap(); + }); + + tokio::time::sleep(Duration::from_millis(100)).await; + + let channel = Endpoint::from_static("http://127.0.0.1:1400") + .user_agent("my-client") + .expect("valid user agent") + .connect() + .await + .unwrap(); + + let mut client = test_client::TestClient::new(channel); + + client.unary_call(Input {}).await.unwrap(); + + tx.send(()).unwrap(); + + jh.await.unwrap(); +} diff --git a/tonic/src/request.rs b/tonic/src/request.rs index e0a033ef6..91bab9e37 100644 --- a/tonic/src/request.rs +++ b/tonic/src/request.rs @@ -1,6 +1,8 @@ use crate::metadata::{MetadataMap, MetadataValue}; +#[cfg(all(feature = "transport", feature = "tls"))] +use crate::transport::server::TlsConnectInfo; #[cfg(feature = "transport")] -use crate::transport::Certificate; +use crate::transport::{server::TcpConnectInfo, Certificate}; use crate::Extensions; use futures_core::Stream; #[cfg(feature = "transport")] @@ -15,13 +17,6 @@ pub struct Request { extensions: Extensions, } -#[derive(Clone)] -pub(crate) struct ConnectionInfo { - pub(crate) remote_addr: Option, - #[cfg(feature = "transport")] - pub(crate) peer_certs: Option>>, -} - /// Trait implemented by RPC request types. /// /// Types implementing this trait can be used as arguments to client RPC @@ -203,7 +198,28 @@ impl Request { /// does not implement `Connected`. This currently, /// only works on the server side. pub fn remote_addr(&self) -> Option { - self.get::()?.remote_addr + #[cfg(feature = "transport")] + { + #[cfg(feature = "tls")] + { + self.get::() + .and_then(|i| i.remote_addr()) + .or_else(|| { + self.get::>() + .and_then(|i| i.get_ref().remote_addr()) + }) + } + + #[cfg(not(feature = "tls"))] + { + self.get::().and_then(|i| i.remote_addr()) + } + } + + #[cfg(not(feature = "transport"))] + { + None + } } /// Get the peer certificates of the connected client. @@ -215,9 +231,19 @@ impl Request { #[cfg(feature = "transport")] #[cfg_attr(docsrs, doc(cfg(feature = "transport")))] pub fn peer_certs(&self) -> Option>> { - self.get::()?.peer_certs.clone() + #[cfg(feature = "tls")] + { + self.get::>() + .and_then(|i| i.peer_certs()) + } + + #[cfg(not(feature = "tls"))] + { + None + } } + #[allow(dead_code)] pub(crate) fn get(&self) -> Option<&I> { self.extensions.get::() } @@ -308,6 +334,15 @@ impl Request { pub fn extensions_mut(&mut self) -> &mut Extensions { &mut self.extensions } + + /// TODO(david) + #[cfg(feature = "transport")] + pub fn connect_info_from_incoming(&self) -> Option<&C::ConnectInfo> + where + C: crate::transport::server::Connected, + { + self.get() + } } impl IntoRequest for T { diff --git a/tonic/src/transport/server/conn.rs b/tonic/src/transport/server/conn.rs index f5bbcfc08..4994f102d 100644 --- a/tonic/src/transport/server/conn.rs +++ b/tonic/src/transport/server/conn.rs @@ -1,7 +1,11 @@ -use crate::transport::Certificate; use hyper::server::conn::AddrStream; use std::net::SocketAddr; use tokio::net::TcpStream; + +#[cfg(feature = "tls")] +use crate::transport::Certificate; +#[cfg(feature = "tls")] +use std::sync::Arc; #[cfg(feature = "tls")] use tokio_rustls::{rustls::Session, server::TlsStream}; @@ -11,48 +15,90 @@ use tokio_rustls::{rustls::Session, server::TlsStream}; /// custom IO types that can still provide the same connection /// metadata. pub trait Connected { - /// Return the remote address this IO resource is connected too. - fn remote_addr(&self) -> Option { - None - } + /// TODO(david) + // all these bounds are necessary to set this as a request extension + type ConnectInfo: Clone + Send + Sync + 'static; + + /// TODO(david) + fn connect_info(&self) -> Self::ConnectInfo; +} + +/// TODO(david) +#[derive(Debug, Clone)] +pub struct TcpConnectInfo { + remote_addr: Option, +} - /// Return the set of connected peer TLS certificates. - fn peer_certs(&self) -> Option> { - None +impl TcpConnectInfo { + /// TODO(david) + pub fn remote_addr(&self) -> Option { + self.remote_addr } } impl Connected for AddrStream { - fn remote_addr(&self) -> Option { - Some(self.remote_addr()) + type ConnectInfo = TcpConnectInfo; + + fn connect_info(&self) -> Self::ConnectInfo { + TcpConnectInfo { + remote_addr: Some(self.remote_addr()), + } } } impl Connected for TcpStream { - fn remote_addr(&self) -> Option { - self.peer_addr().ok() + type ConnectInfo = TcpConnectInfo; + + fn connect_info(&self) -> Self::ConnectInfo { + TcpConnectInfo { + remote_addr: self.peer_addr().ok(), + } } } #[cfg(feature = "tls")] -impl Connected for TlsStream { - fn remote_addr(&self) -> Option { - let (inner, _) = self.get_ref(); +impl Connected for TlsStream +where + T: Connected, +{ + type ConnectInfo = TlsConnectInfo; - inner.remote_addr() - } - - fn peer_certs(&self) -> Option> { - let (_, session) = self.get_ref(); + fn connect_info(&self) -> Self::ConnectInfo { + let (inner, session) = self.get_ref(); + let inner = inner.connect_info(); - if let Some(certs) = session.get_peer_certificates() { + let certs = if let Some(certs) = session.get_peer_certificates() { let certs = certs .into_iter() .map(|c| Certificate::from_pem(c.0)) .collect(); - Some(certs) + Some(Arc::new(certs)) } else { None - } + }; + + TlsConnectInfo { inner, certs } + } +} + +/// TODO(david) +#[cfg(feature = "tls")] +#[derive(Debug, Clone)] +pub struct TlsConnectInfo { + inner: T, + certs: Option>>, +} + +/// TODO(david) +#[cfg(feature = "tls")] +impl TlsConnectInfo { + /// TODO(david) + pub fn get_ref(&self) -> &T { + &self.inner + } + + /// TODO(david) + pub fn peer_certs(&self) -> Option>> { + self.certs.clone() } } diff --git a/tonic/src/transport/server/incoming.rs b/tonic/src/transport/server/incoming.rs index 4ead628f0..e4731c696 100644 --- a/tonic/src/transport/server/incoming.rs +++ b/tonic/src/transport/server/incoming.rs @@ -18,7 +18,7 @@ use tokio::io::{AsyncRead, AsyncWrite}; pub(crate) fn tcp_incoming( incoming: impl Stream>, _server: Server, -) -> impl Stream> +) -> impl Stream, crate::Error>> where IO: AsyncRead + AsyncWrite + Connected + Unpin + Send + 'static, IE: Into, @@ -26,10 +26,8 @@ where async_stream::try_stream! { futures_util::pin_mut!(incoming); - while let Some(stream) = incoming.try_next().await? { - - yield ServerIo::new(stream); + yield ServerIo::new_io(stream); } } } @@ -38,7 +36,7 @@ where pub(crate) fn tcp_incoming( incoming: impl Stream>, server: Server, -) -> impl Stream> +) -> impl Stream, crate::Error>> where IO: AsyncRead + AsyncWrite + Connected + Unpin + Send + 'static, IE: Into, @@ -57,12 +55,12 @@ where let accept = tokio::spawn(async move { let io = tls.accept(stream).await?; - Ok(ServerIo::new(io)) + Ok(ServerIo::new_tls_io(io)) }); tasks.push(accept); } else { - yield ServerIo::new(stream); + yield ServerIo::new_io(stream); } } @@ -86,7 +84,7 @@ where async fn select( incoming: &mut (impl Stream> + Unpin), tasks: &mut futures_util::stream::futures_unordered::FuturesUnordered< - tokio::task::JoinHandle>, + tokio::task::JoinHandle, crate::Error>>, >, ) -> SelectOutput where @@ -124,7 +122,7 @@ where #[cfg(feature = "tls")] enum SelectOutput { Incoming(A), - Io(ServerIo), + Io(ServerIo), Err(crate::Error), Done, } diff --git a/tonic/src/transport/server/mod.rs b/tonic/src/transport/server/mod.rs index 7417532ac..0c5673b6c 100644 --- a/tonic/src/transport/server/mod.rs +++ b/tonic/src/transport/server/mod.rs @@ -7,10 +7,13 @@ mod recover_error; #[cfg_attr(docsrs, doc(cfg(feature = "tls")))] mod tls; -pub use conn::Connected; +pub use conn::{Connected, TcpConnectInfo}; #[cfg(feature = "tls")] pub use tls::ServerTlsConfig; +#[cfg(feature = "tls")] +pub use conn::TlsConnectInfo; + #[cfg(feature = "tls")] use super::service::TlsAcceptor; @@ -24,7 +27,7 @@ use crate::transport::Error; use self::recover_error::RecoverError; use super::service::{GrpcTimeout, Or, Routes, ServerIo}; -use crate::{body::BoxBody, request::ConnectionInfo}; +use crate::body::BoxBody; use futures_core::Stream; use futures_util::{ future::{self, Either as FutureEither, MapErr}, @@ -35,6 +38,7 @@ use hyper::{server::accept, Body}; use std::{ fmt, future::Future, + marker::PhantomData, net::SocketAddr, sync::Arc, task::{Context, Poll}, @@ -372,6 +376,7 @@ impl Server { S::Error: Into + Send, I: Stream>, IO: AsyncRead + AsyncWrite + Connected + Unpin + Send + 'static, + IO::ConnectInfo: Clone + Send + Sync + 'static, IE: Into, F: Future, { @@ -397,6 +402,7 @@ impl Server { concurrency_limit, timeout, trace_interceptor, + _io: PhantomData, }; let server = hyper::Server::builder(incoming) @@ -554,6 +560,7 @@ where where I: Stream>, IO: AsyncRead + AsyncWrite + Connected + Unpin + Send + 'static, + IO::ConnectInfo: Clone + Send + Sync + 'static, IE: Into, { self.server @@ -575,6 +582,7 @@ where where I: Stream>, IO: AsyncRead + AsyncWrite + Connected + Unpin + Send + 'static, + IO::ConnectInfo: Clone + Send + Sync + 'static, IE: Into, F: Future, { @@ -595,16 +603,17 @@ impl fmt::Debug for Server { } } -struct Svc { +struct Svc { inner: S, trace_interceptor: Option, - conn_info: ConnectionInfo, + conn_info: C, } -impl Service> for Svc +impl Service> for Svc where S: Service, Response = Response>, S::Error: Into, + C: Clone + Send + Sync + 'static, { type Response = Response; type Error = crate::Error; @@ -637,24 +646,27 @@ where } } -impl fmt::Debug for Svc { +impl fmt::Debug for Svc { fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result { f.debug_struct("Svc").finish() } } -struct MakeSvc { +struct MakeSvc { concurrency_limit: Option, timeout: Option, inner: S, trace_interceptor: Option, + _io: PhantomData IO>, } -impl Service<&ServerIo> for MakeSvc +impl Service<&ServerIo> for MakeSvc where S: Service, Response = Response> + Clone + Send + 'static, S::Future: Send + 'static, S::Error: Into + Send, + IO: Connected, + IO::ConnectInfo: Clone + Send + Sync + 'static, { type Response = BoxService; type Error = crate::Error; @@ -664,11 +676,8 @@ where Ok(()).into() } - fn call(&mut self, io: &ServerIo) -> Self::Future { - let conn_info = crate::request::ConnectionInfo { - remote_addr: io.remote_addr(), - peer_certs: io.peer_certs().map(Arc::new), - }; + fn call(&mut self, io: &ServerIo) -> Self::Future { + let conn_info = io.as_ref().connect_info(); let svc = self.inner.clone(); let concurrency_limit = self.concurrency_limit; diff --git a/tonic/src/transport/service/io.rs b/tonic/src/transport/service/io.rs index 761c8ece9..17caf45f4 100644 --- a/tonic/src/transport/service/io.rs +++ b/tonic/src/transport/service/io.rs @@ -1,10 +1,11 @@ -use crate::transport::{server::Connected, Certificate}; +use crate::transport::server::Connected; use hyper::client::connect::{Connected as HyperConnected, Connection}; use std::io; -use std::net::SocketAddr; use std::pin::Pin; use std::task::{Context, Poll}; use tokio::io::{AsyncRead, AsyncWrite, ReadBuf}; +#[cfg(feature = "tls")] +use tokio_rustls::server::TlsStream; pub(in crate::transport) trait Io: AsyncRead + AsyncWrite + Send + 'static @@ -27,7 +28,16 @@ impl Connection for BoxedIo { } } -impl Connected for BoxedIo {} +impl Connected for BoxedIo { + type ConnectInfo = NoneConnectInfo; + + fn connect_info(&self) -> Self::ConnectInfo { + NoneConnectInfo + } +} + +#[derive(Copy, Clone)] +pub(crate) struct NoneConnectInfo; impl AsyncRead for BoxedIo { fn poll_read( @@ -57,52 +67,95 @@ impl AsyncWrite for BoxedIo { } } -pub(in crate::transport) trait ConnectedIo: Io + Connected {} - -impl ConnectedIo for T where T: Io + Connected {} +pub(crate) enum ServerIo { + Io(IO), + #[cfg(feature = "tls")] + TlsIo(TlsStream), +} -pub(crate) struct ServerIo(Pin>); +impl ServerIo { + pub(in crate::transport) fn new_io(io: IO) -> Self { + Self::Io(io) + } -impl ServerIo { - pub(in crate::transport) fn new(io: I) -> Self { - ServerIo(Box::pin(io)) + #[cfg(feature = "tls")] + pub(in crate::transport) fn new_tls_io(io: TlsStream) -> Self { + Self::TlsIo(io) } -} -impl Connected for ServerIo { - fn remote_addr(&self) -> Option { - (&*self.0).remote_addr() + pub(crate) fn as_ref(&self) -> &IO { + match self { + Self::Io(io) => io, + #[cfg(feature = "tls")] + Self::TlsIo(io) => { + let (io, _) = io.get_ref(); + io + } + } } +} - fn peer_certs(&self) -> Option> { - (&self.0).peer_certs() +#[cfg(feature = "tls")] +impl Connected for ServerIo +where + IO: Connected, + TlsStream: Connected, +{ + type ConnectInfo = IO::ConnectInfo; + fn connect_info(&self) -> Self::ConnectInfo { + match self { + Self::Io(io) => io.connect_info(), + Self::TlsIo(io) => io.connect_info(), + } } } -impl AsyncRead for ServerIo { +impl AsyncRead for ServerIo +where + IO: AsyncWrite + AsyncRead + Unpin, +{ fn poll_read( mut self: Pin<&mut Self>, cx: &mut Context<'_>, buf: &mut ReadBuf<'_>, ) -> Poll> { - Pin::new(&mut self.0).poll_read(cx, buf) + match &mut *self { + Self::Io(io) => Pin::new(io).poll_read(cx, buf), + #[cfg(feature = "tls")] + Self::TlsIo(io) => Pin::new(io).poll_read(cx, buf), + } } } -impl AsyncWrite for ServerIo { +impl AsyncWrite for ServerIo +where + IO: AsyncWrite + AsyncRead + Unpin, +{ fn poll_write( mut self: Pin<&mut Self>, cx: &mut Context<'_>, buf: &[u8], ) -> Poll> { - Pin::new(&mut self.0).poll_write(cx, buf) + match &mut *self { + Self::Io(io) => Pin::new(io).poll_write(cx, buf), + #[cfg(feature = "tls")] + Self::TlsIo(io) => Pin::new(io).poll_write(cx, buf), + } } fn poll_flush(mut self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll> { - Pin::new(&mut self.0).poll_flush(cx) + match &mut *self { + Self::Io(io) => Pin::new(io).poll_flush(cx), + #[cfg(feature = "tls")] + Self::TlsIo(io) => Pin::new(io).poll_flush(cx), + } } fn poll_shutdown(mut self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll> { - Pin::new(&mut self.0).poll_shutdown(cx) + match &mut *self { + Self::Io(io) => Pin::new(io).poll_shutdown(cx), + #[cfg(feature = "tls")] + Self::TlsIo(io) => Pin::new(io).poll_shutdown(cx), + } } }