Skip to content

Commit

Permalink
feat(transport): provide generic access to connect info
Browse files Browse the repository at this point in the history
  • Loading branch information
davidpdrsn committed May 14, 2021
1 parent f613386 commit 3803722
Show file tree
Hide file tree
Showing 7 changed files with 292 additions and 78 deletions.
23 changes: 22 additions & 1 deletion examples/src/uds/server.rs
Original file line number Diff line number Diff line change
Expand Up @@ -26,6 +26,11 @@ impl Greeter for MyGreeter {
) -> Result<Response<HelloReply>, Status> {
println!("Got a request: {:?}", request);

let conn_info = request
.connect_info_from_incoming::<unix::UnixStream>()
.unwrap();
dbg!(&conn_info);

let reply = hello_world::HelloReply {
message: format!("Hello {}!", request.into_inner().name),
};
Expand Down Expand Up @@ -64,6 +69,7 @@ async fn main() -> Result<(), Box<dyn std::error::Error>> {
mod unix {
use std::{
pin::Pin,
sync::Arc,
task::{Context, Poll},
};

Expand All @@ -73,7 +79,22 @@ mod unix {
#[derive(Debug)]
pub struct UnixStream(pub tokio::net::UnixStream);

impl Connected for UnixStream {}
impl Connected for UnixStream {
type ConnectInfo = UdsConnectInfo;

fn connect_info(&self) -> Self::ConnectInfo {
UdsConnectInfo {
peer_addr: self.0.peer_addr().ok().map(Arc::new),
peer_cred: self.0.peer_cred().ok(),
}
}
}

#[derive(Clone, Debug)]
pub struct UdsConnectInfo {
pub peer_addr: Option<Arc<tokio::net::unix::SocketAddr>>,
pub peer_cred: Option<tokio::net::unix::UCred>,
}

impl AsyncRead for UnixStream {
fn poll_read(
Expand Down
52 changes: 52 additions & 0 deletions tests/integration_tests/tests/connect_info.rs
Original file line number Diff line number Diff line change
@@ -0,0 +1,52 @@
use futures_util::FutureExt;
use integration_tests::pb::{test_client, test_server, Input, Output};
use std::time::Duration;
use tokio::sync::oneshot;
use tonic::{
transport::{Endpoint, Server},
Request, Response, Status,
};

#[tokio::test]
async fn getting_connect_info() {
struct Svc;

#[tonic::async_trait]
impl test_server::Test for Svc {
async fn unary_call(&self, req: Request<Input>) -> Result<Response<Output>, Status> {
match req.metadata().get("user-agent") {
Some(_) => Ok(Response::new(Output {})),
None => Err(Status::internal("user-agent header is missing")),
}
}
}

let svc = test_server::TestServer::new(Svc);

let (tx, rx) = oneshot::channel::<()>();

let jh = tokio::spawn(async move {
Server::builder()
.add_service(svc)
.serve_with_shutdown("127.0.0.1:1400".parse().unwrap(), rx.map(drop))
.await
.unwrap();
});

tokio::time::sleep(Duration::from_millis(100)).await;

let channel = Endpoint::from_static("http://127.0.0.1:1400")
.user_agent("my-client")
.expect("valid user agent")
.connect()
.await
.unwrap();

let mut client = test_client::TestClient::new(channel);

client.unary_call(Input {}).await.unwrap();

tx.send(()).unwrap();

jh.await.unwrap();
}
55 changes: 45 additions & 10 deletions tonic/src/request.rs
Original file line number Diff line number Diff line change
@@ -1,6 +1,8 @@
use crate::metadata::{MetadataMap, MetadataValue};
#[cfg(all(feature = "transport", feature = "tls"))]
use crate::transport::server::TlsConnectInfo;
#[cfg(feature = "transport")]
use crate::transport::Certificate;
use crate::transport::{server::TcpConnectInfo, Certificate};
use crate::Extensions;
use futures_core::Stream;
#[cfg(feature = "transport")]
Expand All @@ -15,13 +17,6 @@ pub struct Request<T> {
extensions: Extensions,
}

#[derive(Clone)]
pub(crate) struct ConnectionInfo {
pub(crate) remote_addr: Option<SocketAddr>,
#[cfg(feature = "transport")]
pub(crate) peer_certs: Option<Arc<Vec<Certificate>>>,
}

/// Trait implemented by RPC request types.
///
/// Types implementing this trait can be used as arguments to client RPC
Expand Down Expand Up @@ -203,7 +198,28 @@ impl<T> Request<T> {
/// does not implement `Connected`. This currently,
/// only works on the server side.
pub fn remote_addr(&self) -> Option<SocketAddr> {
self.get::<ConnectionInfo>()?.remote_addr
#[cfg(feature = "transport")]
{
#[cfg(feature = "tls")]
{
self.get::<TcpConnectInfo>()
.and_then(|i| i.remote_addr())
.or_else(|| {
self.get::<TlsConnectInfo<TcpConnectInfo>>()
.and_then(|i| i.get_ref().remote_addr())
})
}

#[cfg(not(feature = "tls"))]
{
self.get::<TcpConnectInfo>().and_then(|i| i.remote_addr())
}
}

#[cfg(not(feature = "transport"))]
{
None
}
}

/// Get the peer certificates of the connected client.
Expand All @@ -215,9 +231,19 @@ impl<T> Request<T> {
#[cfg(feature = "transport")]
#[cfg_attr(docsrs, doc(cfg(feature = "transport")))]
pub fn peer_certs(&self) -> Option<Arc<Vec<Certificate>>> {
self.get::<ConnectionInfo>()?.peer_certs.clone()
#[cfg(feature = "tls")]
{
self.get::<TlsConnectInfo<TcpConnectInfo>>()
.and_then(|i| i.peer_certs())
}

#[cfg(not(feature = "tls"))]
{
None
}
}

#[allow(dead_code)]
pub(crate) fn get<I: Send + Sync + 'static>(&self) -> Option<&I> {
self.extensions.get::<I>()
}
Expand Down Expand Up @@ -308,6 +334,15 @@ impl<T> Request<T> {
pub fn extensions_mut(&mut self) -> &mut Extensions {
&mut self.extensions
}

/// TODO(david)
#[cfg(feature = "transport")]
pub fn connect_info_from_incoming<C>(&self) -> Option<&C::ConnectInfo>
where
C: crate::transport::server::Connected,
{
self.get()
}
}

impl<T> IntoRequest<T> for T {
Expand Down
92 changes: 69 additions & 23 deletions tonic/src/transport/server/conn.rs
Original file line number Diff line number Diff line change
@@ -1,7 +1,11 @@
use crate::transport::Certificate;
use hyper::server::conn::AddrStream;
use std::net::SocketAddr;
use tokio::net::TcpStream;

#[cfg(feature = "tls")]
use crate::transport::Certificate;
#[cfg(feature = "tls")]
use std::sync::Arc;
#[cfg(feature = "tls")]
use tokio_rustls::{rustls::Session, server::TlsStream};

Expand All @@ -11,48 +15,90 @@ use tokio_rustls::{rustls::Session, server::TlsStream};
/// custom IO types that can still provide the same connection
/// metadata.
pub trait Connected {
/// Return the remote address this IO resource is connected too.
fn remote_addr(&self) -> Option<SocketAddr> {
None
}
/// TODO(david)
// all these bounds are necessary to set this as a request extension
type ConnectInfo: Clone + Send + Sync + 'static;

/// TODO(david)
fn connect_info(&self) -> Self::ConnectInfo;
}

/// TODO(david)
#[derive(Debug, Clone)]
pub struct TcpConnectInfo {
remote_addr: Option<SocketAddr>,
}

/// Return the set of connected peer TLS certificates.
fn peer_certs(&self) -> Option<Vec<Certificate>> {
None
impl TcpConnectInfo {
/// TODO(david)
pub fn remote_addr(&self) -> Option<SocketAddr> {
self.remote_addr
}
}

impl Connected for AddrStream {
fn remote_addr(&self) -> Option<SocketAddr> {
Some(self.remote_addr())
type ConnectInfo = TcpConnectInfo;

fn connect_info(&self) -> Self::ConnectInfo {
TcpConnectInfo {
remote_addr: Some(self.remote_addr()),
}
}
}

impl Connected for TcpStream {
fn remote_addr(&self) -> Option<SocketAddr> {
self.peer_addr().ok()
type ConnectInfo = TcpConnectInfo;

fn connect_info(&self) -> Self::ConnectInfo {
TcpConnectInfo {
remote_addr: self.peer_addr().ok(),
}
}
}

#[cfg(feature = "tls")]
impl<T: Connected> Connected for TlsStream<T> {
fn remote_addr(&self) -> Option<SocketAddr> {
let (inner, _) = self.get_ref();
impl<T> Connected for TlsStream<T>
where
T: Connected,
{
type ConnectInfo = TlsConnectInfo<T::ConnectInfo>;

inner.remote_addr()
}

fn peer_certs(&self) -> Option<Vec<Certificate>> {
let (_, session) = self.get_ref();
fn connect_info(&self) -> Self::ConnectInfo {
let (inner, session) = self.get_ref();
let inner = inner.connect_info();

if let Some(certs) = session.get_peer_certificates() {
let certs = if let Some(certs) = session.get_peer_certificates() {
let certs = certs
.into_iter()
.map(|c| Certificate::from_pem(c.0))
.collect();
Some(certs)
Some(Arc::new(certs))
} else {
None
}
};

TlsConnectInfo { inner, certs }
}
}

/// TODO(david)
#[cfg(feature = "tls")]
#[derive(Debug, Clone)]
pub struct TlsConnectInfo<T> {
inner: T,
certs: Option<Arc<Vec<Certificate>>>,
}

/// TODO(david)
#[cfg(feature = "tls")]
impl<T> TlsConnectInfo<T> {
/// TODO(david)
pub fn get_ref(&self) -> &T {
&self.inner
}

/// TODO(david)
pub fn peer_certs(&self) -> Option<Arc<Vec<Certificate>>> {
self.certs.clone()
}
}
16 changes: 7 additions & 9 deletions tonic/src/transport/server/incoming.rs
Original file line number Diff line number Diff line change
Expand Up @@ -18,18 +18,16 @@ use tokio::io::{AsyncRead, AsyncWrite};
pub(crate) fn tcp_incoming<IO, IE>(
incoming: impl Stream<Item = Result<IO, IE>>,
_server: Server,
) -> impl Stream<Item = Result<ServerIo, crate::Error>>
) -> impl Stream<Item = Result<ServerIo<IO>, crate::Error>>
where
IO: AsyncRead + AsyncWrite + Connected + Unpin + Send + 'static,
IE: Into<crate::Error>,
{
async_stream::try_stream! {
futures_util::pin_mut!(incoming);


while let Some(stream) = incoming.try_next().await? {

yield ServerIo::new(stream);
yield ServerIo::new_io(stream);
}
}
}
Expand All @@ -38,7 +36,7 @@ where
pub(crate) fn tcp_incoming<IO, IE>(
incoming: impl Stream<Item = Result<IO, IE>>,
server: Server,
) -> impl Stream<Item = Result<ServerIo, crate::Error>>
) -> impl Stream<Item = Result<ServerIo<IO>, crate::Error>>
where
IO: AsyncRead + AsyncWrite + Connected + Unpin + Send + 'static,
IE: Into<crate::Error>,
Expand All @@ -57,12 +55,12 @@ where

let accept = tokio::spawn(async move {
let io = tls.accept(stream).await?;
Ok(ServerIo::new(io))
Ok(ServerIo::new_tls_io(io))
});

tasks.push(accept);
} else {
yield ServerIo::new(stream);
yield ServerIo::new_io(stream);
}
}

Expand All @@ -86,7 +84,7 @@ where
async fn select<IO, IE>(
incoming: &mut (impl Stream<Item = Result<IO, IE>> + Unpin),
tasks: &mut futures_util::stream::futures_unordered::FuturesUnordered<
tokio::task::JoinHandle<Result<ServerIo, crate::Error>>,
tokio::task::JoinHandle<Result<ServerIo<IO>, crate::Error>>,
>,
) -> SelectOutput<IO>
where
Expand Down Expand Up @@ -124,7 +122,7 @@ where
#[cfg(feature = "tls")]
enum SelectOutput<A> {
Incoming(A),
Io(ServerIo),
Io(ServerIo<A>),
Err(crate::Error),
Done,
}
Expand Down
Loading

0 comments on commit 3803722

Please sign in to comment.