Skip to content

Commit

Permalink
feat(transport): add user-agent header to client requests. (#457)
Browse files Browse the repository at this point in the history
* Add a default user-agent header to outgoing requests.
* The user agent can be configured through the `Channel` builder.

fixes #453
  • Loading branch information
alce authored Sep 23, 2020
1 parent e9910d1 commit d4899df
Show file tree
Hide file tree
Showing 6 changed files with 165 additions and 2 deletions.
55 changes: 55 additions & 0 deletions tests/integration_tests/tests/user_agent.rs
Original file line number Diff line number Diff line change
@@ -0,0 +1,55 @@
use futures_util::FutureExt;
use integration_tests::pb::{test_client, test_server, Input, Output};
use std::time::Duration;
use tokio::sync::oneshot;
use tonic::{
transport::{Endpoint, Server},
Request, Response, Status,
};

#[tokio::test]
async fn writes_user_agent_header() {
struct Svc;

#[tonic::async_trait]
impl test_server::Test for Svc {
async fn unary_call(&self, req: Request<Input>) -> Result<Response<Output>, Status> {
match req.metadata().get("user-agent") {
Some(_) => Ok(Response::new(Output {})),
None => Err(Status::internal("user-agent header is missing")),
}
}
}

let svc = test_server::TestServer::new(Svc);

let (tx, rx) = oneshot::channel::<()>();

let jh = tokio::spawn(async move {
Server::builder()
.add_service(svc)
.serve_with_shutdown("127.0.0.1:1322".parse().unwrap(), rx.map(drop))
.await
.unwrap();
});

tokio::time::delay_for(Duration::from_millis(100)).await;

let channel = Endpoint::from_static("http://127.0.0.1:1322")
.user_agent("my-client")
.expect("valid user agent")
.connect()
.await
.unwrap();

let mut client = test_client::TestClient::new(channel);

match client.unary_call(Input {}).await {
Ok(_) => {}
Err(status) => panic!("{}", status.message()),
}

tx.send(()).unwrap();

jh.await.unwrap();
}
31 changes: 30 additions & 1 deletion tonic/src/transport/channel/endpoint.rs
Original file line number Diff line number Diff line change
Expand Up @@ -6,7 +6,10 @@ use super::ClientTlsConfig;
use crate::transport::service::TlsConnector;
use crate::transport::Error;
use bytes::Bytes;
use http::uri::{InvalidUri, Uri};
use http::{
uri::{InvalidUri, Uri},
HeaderValue,
};
use std::{
convert::{TryFrom, TryInto},
fmt,
Expand All @@ -20,6 +23,7 @@ use tower_make::MakeConnection;
#[derive(Clone)]
pub struct Endpoint {
pub(crate) uri: Uri,
pub(crate) user_agent: Option<HeaderValue>,
pub(crate) timeout: Option<Duration>,
pub(crate) concurrency_limit: Option<usize>,
pub(crate) rate_limit: Option<(u64, Duration)>,
Expand Down Expand Up @@ -74,6 +78,30 @@ impl Endpoint {
Ok(Self::from(uri))
}

/// Set a custom user-agent header.
///
/// `user_agent` will be prepended to Tonic's default user-agent string (`tonic/x.x.x`).
/// It must be a value that can be converted into a valid `http::HeaderValue` or building
/// the endpoint will fail.
/// ```
/// # use tonic::transport::Endpoint;
/// # let mut builder = Endpoint::from_static("https://example.com");
/// builder.user_agent("Greeter").expect("Greeter should be a valid header value");
/// // user-agent: "Greeter tonic/x.x.x"
/// ```
pub fn user_agent<T>(self, user_agent: T) -> Result<Self, Error>
where
T: TryInto<HeaderValue>,
{
user_agent
.try_into()
.map(|ua| Endpoint {
user_agent: Some(ua),
..self
})
.map_err(|_| Error::new_invalid_user_agent())
}

/// Apply a timeout to each request.
///
/// ```
Expand Down Expand Up @@ -276,6 +304,7 @@ impl From<Uri> for Endpoint {
fn from(uri: Uri) -> Self {
Self {
uri,
user_agent: None,
concurrency_limit: None,
rate_limit: None,
timeout: None,
Expand Down
6 changes: 6 additions & 0 deletions tonic/src/transport/error.rs
Original file line number Diff line number Diff line change
Expand Up @@ -21,6 +21,7 @@ struct ErrorImpl {
pub(crate) enum Kind {
Transport,
InvalidUri,
InvalidUserAgent,
}

impl Error {
Expand All @@ -43,10 +44,15 @@ impl Error {
Error::new(Kind::InvalidUri)
}

pub(crate) fn new_invalid_user_agent() -> Self {
Error::new(Kind::InvalidUserAgent)
}

fn description(&self) -> &str {
match &self.inner.kind {
Kind::Transport => "transport error",
Kind::InvalidUri => "invalid URI",
Kind::InvalidUserAgent => "user agent is not a valid header value",
}
}
}
Expand Down
3 changes: 2 additions & 1 deletion tonic/src/transport/service/connection.rs
Original file line number Diff line number Diff line change
@@ -1,4 +1,4 @@
use super::{layer::ServiceBuilderExt, reconnect::Reconnect, AddOrigin};
use super::{layer::ServiceBuilderExt, reconnect::Reconnect, AddOrigin, UserAgent};
use crate::{body::BoxBody, transport::Endpoint};
use http::Uri;
use hyper::client::conn::Builder;
Expand Down Expand Up @@ -55,6 +55,7 @@ impl Connection {

let stack = ServiceBuilder::new()
.layer_fn(|s| AddOrigin::new(s, endpoint.uri.clone()))
.layer_fn(|s| UserAgent::new(s, endpoint.user_agent.clone()))
.optional_layer(endpoint.timeout.map(TimeoutLayer::new))
.optional_layer(endpoint.concurrency_limit.map(ConcurrencyLimitLayer::new))
.optional_layer(endpoint.rate_limit.map(|(l, d)| RateLimitLayer::new(l, d)))
Expand Down
2 changes: 2 additions & 0 deletions tonic/src/transport/service/mod.rs
Original file line number Diff line number Diff line change
Expand Up @@ -8,6 +8,7 @@ mod reconnect;
mod router;
#[cfg(feature = "tls")]
mod tls;
mod user_agent;

pub(crate) use self::add_origin::AddOrigin;
pub(crate) use self::connection::Connection;
Expand All @@ -18,3 +19,4 @@ pub(crate) use self::layer::ServiceBuilderExt;
pub(crate) use self::router::{Or, Routes};
#[cfg(feature = "tls")]
pub(crate) use self::tls::{TlsAcceptor, TlsConnector};
pub(crate) use self::user_agent::UserAgent;
70 changes: 70 additions & 0 deletions tonic/src/transport/service/user_agent.rs
Original file line number Diff line number Diff line change
@@ -0,0 +1,70 @@
use http::{header::USER_AGENT, HeaderValue, Request};
use std::task::{Context, Poll};
use tower_service::Service;

const TONIC_USER_AGENT: &str = concat!("tonic/", env!("CARGO_PKG_VERSION"));

#[derive(Debug)]
pub(crate) struct UserAgent<T> {
inner: T,
user_agent: HeaderValue,
}

impl<T> UserAgent<T> {
pub(crate) fn new(inner: T, user_agent: Option<HeaderValue>) -> Self {
let user_agent = user_agent
.map(|value| {
let mut buf = Vec::new();
buf.extend(value.as_bytes());
buf.push(b' ');
buf.extend(TONIC_USER_AGENT.as_bytes());
HeaderValue::from_bytes(&buf).expect("user-agent should be valid")
})
.unwrap_or(HeaderValue::from_static(TONIC_USER_AGENT));

Self { inner, user_agent }
}
}

impl<T, ReqBody> Service<Request<ReqBody>> for UserAgent<T>
where
T: Service<Request<ReqBody>>,
{
type Response = T::Response;
type Error = T::Error;
type Future = T::Future;

fn poll_ready(&mut self, cx: &mut Context<'_>) -> Poll<Result<(), Self::Error>> {
self.inner.poll_ready(cx)
}

fn call(&mut self, mut req: Request<ReqBody>) -> Self::Future {
req.headers_mut()
.insert(USER_AGENT, self.user_agent.clone());

self.inner.call(req)
}
}

#[cfg(test)]
mod tests {
use super::*;

struct Svc;

#[test]
fn sets_default_if_no_custom_user_agent() {
assert_eq!(
UserAgent::new(Svc, None).user_agent,
HeaderValue::from_static(TONIC_USER_AGENT)
)
}

#[test]
fn prepends_custom_user_agent_to_default() {
assert_eq!(
UserAgent::new(Svc, Some(HeaderValue::from_static("Greeter 1.1"))).user_agent,
HeaderValue::from_str(&format!("Greeter 1.1 {}", TONIC_USER_AGENT)).unwrap()
)
}
}

0 comments on commit d4899df

Please sign in to comment.