diff --git a/Cargo.toml b/Cargo.toml index 0550331ab..4f7f24f7f 100644 --- a/Cargo.toml +++ b/Cargo.toml @@ -22,3 +22,4 @@ members = [ [replace] 'prost:0.6.1' = { git = "https://github.com/jen20/prost", branch = "jen20/file-descriptor-set" } 'prost-build:0.6.1' = { git = "https://github.com/jen20/prost", branch = "jen20/file-descriptor-set" } +'prost-types:0.6.1' = { git = "https://github.com/jen20/prost", branch = "jen20/file-descriptor-set" } diff --git a/examples/Cargo.toml b/examples/Cargo.toml index 0cd42cbae..f4b24b107 100644 --- a/examples/Cargo.toml +++ b/examples/Cargo.toml @@ -114,6 +114,10 @@ path = "src/hyper_warp/server.rs" name = "health-server" path = "src/health/server.rs" +[[bin]] +name = "reflection-server" +path = "src/reflection/server.rs" + [[bin]] name = "autoreload-server" path = "src/autoreload/server.rs" @@ -144,6 +148,8 @@ http-body = "0.3" pin-project = "0.4" # Health example tonic-health = { path = "../tonic-health" } +# Reflection example +tonic-reflection = { path = "../tonic-reflection" } listenfd = "0.3" [build-dependencies] diff --git a/examples/build.rs b/examples/build.rs index 97ab2d038..825285e6c 100644 --- a/examples/build.rs +++ b/examples/build.rs @@ -1,5 +1,8 @@ fn main() { - tonic_build::compile_protos("proto/helloworld/helloworld.proto").unwrap(); + tonic_build::configure() + .include_file_descriptor_set(true) + .compile(&["proto/helloworld/helloworld.proto"], &["proto/"]) + .unwrap(); tonic_build::compile_protos("proto/routeguide/route_guide.proto").unwrap(); tonic_build::compile_protos("proto/echo/echo.proto").unwrap(); tonic_build::compile_protos("proto/google/pubsub/pubsub.proto").unwrap(); diff --git a/examples/src/reflection/server.rs b/examples/src/reflection/server.rs new file mode 100644 index 000000000..fd547353c --- /dev/null +++ b/examples/src/reflection/server.rs @@ -0,0 +1,45 @@ +use tonic::transport::Server; +use tonic::{Request, Response, Status}; + +mod proto { + tonic::include_proto!("helloworld"); + + pub(crate) const FILE_DESCRIPTOR_SET: &'static [u8] = tonic::include_file_descriptor_set!(); +} + +#[derive(Default)] +pub struct MyGreeter {} + +#[tonic::async_trait] +impl proto::greeter_server::Greeter for MyGreeter { + async fn say_hello( + &self, + request: Request, + ) -> Result, Status> { + println!("Got a request from {:?}", request.remote_addr()); + + let reply = proto::HelloReply { + message: format!("Hello {}!", request.into_inner().name), + }; + Ok(Response::new(reply)) + } +} + +#[tokio::main] +async fn main() -> Result<(), Box> { + let service = tonic_reflection::server::Builder::configure() + .register_encoded_file_descriptor_set(proto::FILE_DESCRIPTOR_SET) + .build() + .unwrap(); + + let addr = "[::1]:50052".parse().unwrap(); + let greeter = MyGreeter::default(); + + Server::builder() + .add_service(service) + .add_service(proto::greeter_server::GreeterServer::new(greeter)) + .serve(addr) + .await?; + + Ok(()) +} diff --git a/tonic-reflection/Cargo.toml b/tonic-reflection/Cargo.toml new file mode 100644 index 000000000..e2d0e3593 --- /dev/null +++ b/tonic-reflection/Cargo.toml @@ -0,0 +1,24 @@ +[package] +name = "tonic-reflection" +version = "0.1.0" +authors = ["James Nugent "] +edition = "2018" +license = "MIT" +repository = "https://github.com/hyperium/tonic" +homepage = "https://github.com/hyperium/tonic" +description = """ +Server Reflection module of `tonic` gRPC implementation. +""" +readme = "README.md" +categories = ["network-programming", "asynchronous"] +keywords = ["rpc", "grpc", "async", "reflection"] + +[dependencies] +bytes = "0.5" +prost = "0.6" +prost-types = "0.6" +tokio = { version = "0.2", features = ["sync", "stream"] } +tonic = { version = "0.2", path = "../tonic", features = ["codegen", "prost"] } + +[build-dependencies] +tonic-build = { version = "0.2", path = "../tonic-build" } diff --git a/tonic-reflection/build.rs b/tonic-reflection/build.rs new file mode 100644 index 000000000..8050c31e5 --- /dev/null +++ b/tonic-reflection/build.rs @@ -0,0 +1,10 @@ +fn main() -> Result<(), Box> { + tonic_build::configure() + .include_file_descriptor_set(true) + .build_server(true) + .build_client(false) + .format(true) + .compile(&["proto/reflection.proto"], &["proto/"])?; + + Ok(()) +} diff --git a/tonic-reflection/proto/reflection.proto b/tonic-reflection/proto/reflection.proto new file mode 100644 index 000000000..c2da31461 --- /dev/null +++ b/tonic-reflection/proto/reflection.proto @@ -0,0 +1,136 @@ +// Copyright 2016 gRPC authors. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +// Service exported by server reflection + +syntax = "proto3"; + +package grpc.reflection.v1alpha; + +service ServerReflection { + // The reflection service is structured as a bidirectional stream, ensuring + // all related requests go to a single server. + rpc ServerReflectionInfo(stream ServerReflectionRequest) + returns (stream ServerReflectionResponse); +} + +// The message sent by the client when calling ServerReflectionInfo method. +message ServerReflectionRequest { + string host = 1; + // To use reflection service, the client should set one of the following + // fields in message_request. The server distinguishes requests by their + // defined field and then handles them using corresponding methods. + oneof message_request { + // Find a proto file by the file name. + string file_by_filename = 3; + + // Find the proto file that declares the given fully-qualified symbol name. + // This field should be a fully-qualified symbol name + // (e.g. .[.] or .). + string file_containing_symbol = 4; + + // Find the proto file which defines an extension extending the given + // message type with the given field number. + ExtensionRequest file_containing_extension = 5; + + // Finds the tag numbers used by all known extensions of extendee_type, and + // appends them to ExtensionNumberResponse in an undefined order. + // Its corresponding method is best-effort: it's not guaranteed that the + // reflection service will implement this method, and it's not guaranteed + // that this method will provide all extensions. Returns + // StatusCode::UNIMPLEMENTED if it's not implemented. + // This field should be a fully-qualified type name. The format is + // . + string all_extension_numbers_of_type = 6; + + // List the full names of registered services. The content will not be + // checked. + string list_services = 7; + } +} + +// The type name and extension number sent by the client when requesting +// file_containing_extension. +message ExtensionRequest { + // Fully-qualified type name. The format should be . + string containing_type = 1; + int32 extension_number = 2; +} + +// The message sent by the server to answer ServerReflectionInfo method. +message ServerReflectionResponse { + string valid_host = 1; + ServerReflectionRequest original_request = 2; + // The server sets one of the following fields according to the + // message_request in the request. + oneof message_response { + // This message is used to answer file_by_filename, file_containing_symbol, + // file_containing_extension requests with transitive dependencies. + // As the repeated label is not allowed in oneof fields, we use a + // FileDescriptorResponse message to encapsulate the repeated fields. + // The reflection service is allowed to avoid sending FileDescriptorProtos + // that were previously sent in response to earlier requests in the stream. + FileDescriptorResponse file_descriptor_response = 4; + + // This message is used to answer all_extension_numbers_of_type requests. + ExtensionNumberResponse all_extension_numbers_response = 5; + + // This message is used to answer list_services requests. + ListServiceResponse list_services_response = 6; + + // This message is used when an error occurs. + ErrorResponse error_response = 7; + } +} + +// Serialized FileDescriptorProto messages sent by the server answering +// a file_by_filename, file_containing_symbol, or file_containing_extension +// request. +message FileDescriptorResponse { + // Serialized FileDescriptorProto messages. We avoid taking a dependency on + // descriptor.proto, which uses proto2 only features, by making them opaque + // bytes instead. + repeated bytes file_descriptor_proto = 1; +} + +// A list of extension numbers sent by the server answering +// all_extension_numbers_of_type request. +message ExtensionNumberResponse { + // Full name of the base type, including the package name. The format + // is . + string base_type_name = 1; + repeated int32 extension_number = 2; +} + +// A list of ServiceResponse sent by the server answering list_services request. +message ListServiceResponse { + // The information of each service may be expanded in the future, so we use + // ServiceResponse message to encapsulate it. + repeated ServiceResponse service = 1; +} + +// The information of a single service used by ListServiceResponse to answer +// list_services request. +message ServiceResponse { + // Full name of a registered service, including its package name. The format + // is . + string name = 1; +} + +// The error code and error message sent by the server when an error occurs. +message ErrorResponse { + // This field uses the error codes defined in grpc::StatusCode. + int32 error_code = 1; + string error_message = 2; +} \ No newline at end of file diff --git a/tonic-reflection/src/lib.rs b/tonic-reflection/src/lib.rs new file mode 100644 index 000000000..3d5f4fc03 --- /dev/null +++ b/tonic-reflection/src/lib.rs @@ -0,0 +1,25 @@ +//! A `tonic` based gRPC Server Reflection implementation. + +#![warn( + missing_debug_implementations, + missing_docs, + rust_2018_idioms, + unreachable_pub +)] +#![doc( + html_logo_url = "https://github.com/hyperium/tonic/raw/master/.github/assets/tonic-docs.png" +)] +#![doc(html_root_url = "https://docs.rs/tonic-reflection/0.1.0")] +#![doc(issue_tracker_base_url = "https://github.com/hyperium/tonic/issues/")] +#![doc(test(no_crate_inject, attr(deny(rust_2018_idioms))))] +#![cfg_attr(docsrs, feature(doc_cfg))] + +mod proto { + #![allow(unreachable_pub)] + tonic::include_proto!("grpc.reflection.v1alpha"); + + pub(crate) const FILE_DESCRIPTOR_SET: &'static [u8] = tonic::include_file_descriptor_set!(); +} + +/// Implementation of the server component of gRPC Server Reflection. +pub mod server; diff --git a/tonic-reflection/src/server.rs b/tonic-reflection/src/server.rs new file mode 100644 index 000000000..efc12d5c9 --- /dev/null +++ b/tonic-reflection/src/server.rs @@ -0,0 +1,361 @@ +use crate::proto::server_reflection_request::MessageRequest; +use crate::proto::server_reflection_response::MessageResponse; +use crate::proto::server_reflection_server::{ServerReflection, ServerReflectionServer}; +use crate::proto::{ + FileDescriptorResponse, ListServiceResponse, ServerReflectionRequest, ServerReflectionResponse, + ServiceResponse, +}; +use prost::{DecodeError, Message}; +use prost_types::{ + DescriptorProto, EnumDescriptorProto, FieldDescriptorProto, FileDescriptorProto, + FileDescriptorSet, +}; +use std::collections::HashMap; +use std::fmt::{Display, Formatter}; +use std::sync::Arc; +use tokio::stream::StreamExt; +use tokio::sync::mpsc; +use tonic::{Request, Response, Status, Streaming}; + +/// Represents an error in the construction of a gRPC Reflection Service. +#[derive(Debug)] +pub enum Error { + /// An error was encountered decoding a `prost_types::FileDescriptorSet` from a buffer. + DecodeError(prost::DecodeError), + /// An invalid `prost_types::FileDescriptorProto` was encountered. + InvalidFileDescriptorSet(String), +} + +impl From for Error { + fn from(e: DecodeError) -> Self { + Error::DecodeError(e) + } +} + +impl std::error::Error for Error {} + +impl Display for Error { + fn fmt(&self, f: &mut Formatter<'_>) -> std::fmt::Result { + match self { + Error::DecodeError(_) => f.write_str("error decoding FileDescriptorSet from buffer"), + Error::InvalidFileDescriptorSet(s) => { + f.write_fmt(format_args!("invalid FileDescriptorSet - {}", s)) + } + } + } +} + +/// A builder used to construct a gRPC Reflection Service. +#[derive(Debug)] +pub struct Builder<'b> { + file_descriptor_sets: Vec, + encoded_file_descriptor_sets: Vec<&'b [u8]>, + include_reflection_service: bool, + + service_names: Vec, + symbols: HashMap>, +} + +impl<'b> Builder<'b> { + /// Create a new builder that can configure a gRPC Reflection Service. + pub fn configure() -> Self { + Builder { + file_descriptor_sets: Vec::new(), + encoded_file_descriptor_sets: Vec::new(), + include_reflection_service: true, + + service_names: Vec::new(), + symbols: HashMap::new(), + } + } + + /// Registers an instance of `prost_types::FileDescriptorSet` with the gRPC Reflection + /// Service builder. + pub fn register_file_descriptor_set(mut self, file_descriptor_set: FileDescriptorSet) -> Self { + self.file_descriptor_sets.push(file_descriptor_set); + self + } + + /// Registers a byte slice containing an encoded `prost_types::FileDescriptorSet` with + /// the gRPC Reflection Service builder. + pub fn register_encoded_file_descriptor_set( + mut self, + encoded_file_descriptor_set: &'b [u8], + ) -> Self { + self.encoded_file_descriptor_sets + .push(encoded_file_descriptor_set); + self + } + + /// Serve the gRPC Refection Service descriptor via the Reflection Service. This is enabled + /// by default - set `include` to false to disable. + pub fn include_reflection_service(mut self, include: bool) -> Self { + self.include_reflection_service = include; + self + } + + /// Build a gRPC Reflection Service to be served via Tonic. + pub fn build(mut self) -> Result, Error> { + if self.include_reflection_service { + self = self.register_encoded_file_descriptor_set(crate::proto::FILE_DESCRIPTOR_SET); + } + + for encoded in &self.encoded_file_descriptor_sets { + let decoded = FileDescriptorSet::decode(*encoded)?; + self.file_descriptor_sets.push(decoded); + } + + let all_fds = self.file_descriptor_sets.clone(); + let mut files: HashMap> = HashMap::new(); + + for fds in all_fds { + for fd in fds.file { + let name = match fd.name.clone() { + None => { + return Err(Error::InvalidFileDescriptorSet("missing name".to_string())); + } + Some(n) => n, + }; + + if files.contains_key(&name) { + continue; + } + + let fd = Arc::new(fd); + files.insert(name, fd.clone()); + + self.process_file(fd)?; + } + } + + let service_names = self + .service_names + .iter() + .map(|name| ServiceResponse { name: name.clone() }) + .collect(); + + Ok(ServerReflectionServer::new(ReflectionService { + state: Arc::new(ReflectionServiceState { + service_names, + files, + symbols: self.symbols, + }), + })) + } + + fn process_file(&mut self, fd: Arc) -> Result<(), Error> { + let prefix = &fd.package.clone().unwrap_or_default(); + + for msg in &fd.message_type { + self.process_message(fd.clone(), &prefix, msg)?; + } + + for en in &fd.enum_type { + self.process_enum(fd.clone(), &prefix, en)?; + } + + for service in &fd.service { + let service_name = extract_name(&prefix, "service", service.name.as_ref())?; + self.service_names.push(service_name.clone()); + self.symbols.insert(service_name.clone(), fd.clone()); + + for method in &service.method { + let method_name = extract_name(&service_name, "method", method.name.as_ref())?; + self.symbols.insert(method_name, fd.clone()); + } + } + + Ok(()) + } + + fn process_message( + &mut self, + fd: Arc, + prefix: &str, + msg: &DescriptorProto, + ) -> Result<(), Error> { + let message_name = extract_name(prefix, "message", msg.name.as_ref())?; + self.symbols.insert(message_name.clone(), fd.clone()); + + for nested in &msg.nested_type { + self.process_message(fd.clone(), &message_name, nested)?; + } + + for en in &msg.enum_type { + self.process_enum(fd.clone(), &message_name, en)?; + } + + for field in &msg.field { + self.process_field(fd.clone(), &message_name, field)?; + } + + for oneof in &msg.oneof_decl { + let oneof_name = extract_name(&message_name, "oneof", oneof.name.as_ref())?; + self.symbols.insert(oneof_name, fd.clone()); + } + + Ok(()) + } + + fn process_enum( + &mut self, + fd: Arc, + prefix: &str, + en: &EnumDescriptorProto, + ) -> Result<(), Error> { + let enum_name = extract_name(prefix, "enum", en.name.as_ref())?; + self.symbols.insert(enum_name.clone(), fd.clone()); + + for value in &en.value { + let value_name = extract_name(&enum_name, "enum value", value.name.as_ref())?; + self.symbols.insert(value_name, fd.clone()); + } + + Ok(()) + } + + fn process_field( + &mut self, + fd: Arc, + prefix: &str, + field: &FieldDescriptorProto, + ) -> Result<(), Error> { + let field_name = extract_name(prefix, "field", field.name.as_ref())?; + self.symbols.insert(field_name, fd.clone()); + Ok(()) + } +} + +fn extract_name( + prefix: &str, + name_type: &str, + maybe_name: Option<&String>, +) -> Result { + match maybe_name { + None => Err(Error::InvalidFileDescriptorSet(format!( + "missing {} name", + name_type + ))), + Some(name) => { + if prefix.is_empty() { + Ok(name.to_string()) + } else { + Ok(format!("{}.{}", prefix, name)) + } + } + } +} + +#[derive(Debug)] +struct ReflectionServiceState { + service_names: Vec, + files: HashMap>, + symbols: HashMap>, +} + +impl ReflectionServiceState { + fn list_services(&self) -> MessageResponse { + MessageResponse::ListServicesResponse(ListServiceResponse { + service: self.service_names.clone(), + }) + } + + fn symbol_by_name(&self, symbol: &str) -> Result { + match self.symbols.get(symbol) { + None => Err(Status::not_found(format!("symbol '{}' not found", symbol))), + Some(fd) => { + let mut encoded_fd = Vec::new(); + if let Err(_) = fd.clone().encode(&mut encoded_fd) { + return Err(Status::internal("encoding error")); + }; + + Ok(MessageResponse::FileDescriptorResponse( + FileDescriptorResponse { + file_descriptor_proto: vec![encoded_fd], + }, + )) + } + } + } + + fn file_by_filename(&self, filename: &str) -> Result { + match self.files.get(filename) { + None => Err(Status::not_found(format!("file '{}' not found", filename))), + Some(fd) => { + let mut encoded_fd = Vec::new(); + if let Err(_) = fd.clone().encode(&mut encoded_fd) { + return Err(Status::internal("encoding error")); + } + + Ok(MessageResponse::FileDescriptorResponse( + FileDescriptorResponse { + file_descriptor_proto: vec![encoded_fd], + }, + )) + } + } + } +} + +#[derive(Debug)] +struct ReflectionService { + state: Arc, +} + +#[tonic::async_trait] +impl ServerReflection for ReflectionService { + type ServerReflectionInfoStream = mpsc::Receiver>; + + async fn server_reflection_info( + &self, + req: Request>, + ) -> Result, Status> { + let mut req_rx = req.into_inner(); + let (mut resp_tx, resp_rx) = mpsc::channel::>(1); + + let state = self.state.clone(); + + tokio::spawn(async move { + while let Some(req) = req_rx.next().await { + let req = match req { + Ok(req) => req, + Err(_) => { + return; + } + }; + + let resp_msg = match req.message_request.clone() { + None => Err(Status::invalid_argument("invalid MessageRequest")), + Some(msg) => match msg { + MessageRequest::FileByFilename(s) => state.file_by_filename(&s), + MessageRequest::FileContainingSymbol(s) => state.symbol_by_name(&s), + MessageRequest::FileContainingExtension(_) => { + Err(Status::not_found("extensions are not supported")) + } + MessageRequest::AllExtensionNumbersOfType(_) => { + Err(Status::not_found("extensions are not supported")) + } + MessageRequest::ListServices(_) => Ok(state.list_services()), + }, + }; + + match resp_msg { + Ok(resp_msg) => { + let resp = ServerReflectionResponse { + valid_host: req.host.clone(), + original_request: Some(req.clone()), + message_response: Some(resp_msg), + }; + resp_tx.send(Ok(resp)).await.expect("send"); + } + Err(status) => { + resp_tx.send(Err(status)).await.expect("send"); + return; + } + } + } + }); + + Ok(Response::new(resp_rx)) + } +}