diff --git a/protobuf/src/core.rs b/protobuf/src/core.rs index 9efb83221..c0c2731b2 100644 --- a/protobuf/src/core.rs +++ b/protobuf/src/core.rs @@ -166,7 +166,7 @@ pub trait Message: fmt::Debug + Clear + Send + Sync + ProtobufValue { Self: Sized; } -impl Message { +impl dyn Message { pub fn downcast_box(self: Box) -> Result, Box> { if self.as_any().is::() { unsafe { @@ -177,6 +177,21 @@ impl Message { Err(self) } } + + pub fn downcast_ref<'a, M: Message + 'a>(&'a self) -> Option<&'a M> { + self.as_any().downcast_ref::() + } + + pub fn downcast_mut<'a, M: Message + 'a>(&'a mut self) -> Option<&'a mut M> { + if self.as_any().is::() { + unsafe { + Some(&mut *(self as *mut dyn Message as *mut M)) + } + } else { + None + } + + } } impl Clone for Box { @@ -194,21 +209,6 @@ impl PartialEq for Box { } } -pub fn message_down_cast_ref<'a, M: Message + 'a>(m: &'a Message) -> Option<&'a M> { - m.as_any().downcast_ref::() -} - -pub fn message_down_cast_mut<'a, M: Message + 'a>(m: &'a mut Message) -> Option<&'a mut M> { - if m.as_any().is::() { - unsafe { - Some(&mut *(m as *mut dyn Message as *mut M)) - } - } else { - None - } - -} - /// Parse message from stream. pub fn parse_from(is: &mut CodedInputStream) -> ProtobufResult { let mut r: M = Message::new(); diff --git a/protobuf/src/json/parse.rs b/protobuf/src/json/parse.rs index 44263f9a6..74ee6fbbe 100644 --- a/protobuf/src/json/parse.rs +++ b/protobuf/src/json/parse.rs @@ -48,7 +48,7 @@ use well_known_types::UInt32Value; use well_known_types::UInt64Value; use well_known_types::Value; use well_known_types::Value_oneof_kind; -use core::message_down_cast_mut; + #[derive(Debug)] pub enum ParseError { @@ -571,67 +571,67 @@ impl<'a> Parser<'a> { } fn merge_inner(&mut self, message: &mut Message) -> ParseResult<()> { - if let Some(duration) = message_down_cast_mut(message) { + if let Some(duration) = message.downcast_mut() { return self.merge_wk_duration(duration); } - if let Some(timestamp) = message_down_cast_mut(message) { + if let Some(timestamp) = message.downcast_mut() { return self.merge_wk_timestamp(timestamp); } - if let Some(field_mask) = message_down_cast_mut(message) { + if let Some(field_mask) = message.downcast_mut() { return self.merge_wk_field_mask(field_mask); } - if let Some(value) = message_down_cast_mut(message) { + if let Some(value) = message.downcast_mut() { return self.merge_wk_value(value); } - if let Some(value) = message_down_cast_mut(message) { + if let Some(value) = message.downcast_mut() { return self.merge_wk_any(value); } - if let Some(value) = message_down_cast_mut::(message) { + if let Some(value) = message.downcast_mut::() { return self.merge_wrapper(value); } - if let Some(value) = message_down_cast_mut::(message) { + if let Some(value) = message.downcast_mut::() { return self.merge_wrapper(value); } - if let Some(value) = message_down_cast_mut::(message) { + if let Some(value) = message.downcast_mut::() { return self.merge_wrapper(value); } - if let Some(value) = message_down_cast_mut::(message) { + if let Some(value) = message.downcast_mut::() { return self.merge_wrapper(value); } - if let Some(value) = message_down_cast_mut::(message) { + if let Some(value) = message.downcast_mut::() { return self.merge_wrapper(value); } - if let Some(value) = message_down_cast_mut::(message) { + if let Some(value) = message.downcast_mut::() { return self.merge_wrapper(value); } - if let Some(value) = message_down_cast_mut::(message) { + if let Some(value) = message.downcast_mut::() { return self.merge_bool_value(value); } - if let Some(value) = message_down_cast_mut::(message) { + if let Some(value) = message.downcast_mut::() { return self.merge_string_value(value); } - if let Some(value) = message_down_cast_mut::(message) { + if let Some(value) = message.downcast_mut::() { return self.merge_bytes_value(value); } - if let Some(value) = message_down_cast_mut::(message) { + if let Some(value) = message.downcast_mut::() { return self.merge_wk_list_value(value); } - if let Some(value) = message_down_cast_mut::(message) { + if let Some(value) = message.downcast_mut::() { return self.merge_wk_struct(value); } diff --git a/protobuf/src/json/print.rs b/protobuf/src/json/print.rs index 4bdeea83c..910d82e84 100644 --- a/protobuf/src/json/print.rs +++ b/protobuf/src/json/print.rs @@ -10,7 +10,7 @@ use reflect::ReflectValueRef; use std::f32; use std::f64; use Message; -use core::message_down_cast_ref; + use well_known_types::Any; use well_known_types::BoolValue; @@ -401,37 +401,37 @@ impl Printer { } fn print_message(&mut self, message: &Message) -> PrintResult<()> { - if let Some(duration) = message_down_cast_ref::(message) { + if let Some(duration) = message.downcast_ref::() { self.print_printable(duration) - } else if let Some(timestamp) = message_down_cast_ref::(message) { + } else if let Some(timestamp) = message.downcast_ref::() { self.print_printable(timestamp) - } else if let Some(field_mask) = message_down_cast_ref::(message) { + } else if let Some(field_mask) = message.downcast_ref::() { self.print_printable(field_mask) - } else if let Some(any) = message_down_cast_ref::(message) { + } else if let Some(any) = message.downcast_ref::() { self.print_printable(any) - } else if let Some(value) = message_down_cast_ref::(message) { + } else if let Some(value) = message.downcast_ref::() { self.print_printable(value) - } else if let Some(value) = message_down_cast_ref::(message) { + } else if let Some(value) = message.downcast_ref::() { self.print_wrapper(value) - } else if let Some(value) = message_down_cast_ref::(message) { + } else if let Some(value) = message.downcast_ref::() { self.print_wrapper(value) - } else if let Some(value) = message_down_cast_ref::(message) { + } else if let Some(value) = message.downcast_ref::() { self.print_wrapper(value) - } else if let Some(value) = message_down_cast_ref::(message) { + } else if let Some(value) = message.downcast_ref::() { self.print_wrapper(value) - } else if let Some(value) = message_down_cast_ref::(message) { + } else if let Some(value) = message.downcast_ref::() { self.print_wrapper(value) - } else if let Some(value) = message_down_cast_ref::(message) { + } else if let Some(value) = message.downcast_ref::() { self.print_wrapper(value) - } else if let Some(value) = message_down_cast_ref::(message) { + } else if let Some(value) = message.downcast_ref::() { self.print_wrapper(value) - } else if let Some(value) = message_down_cast_ref::(message) { + } else if let Some(value) = message.downcast_ref::() { self.print_wrapper(value) - } else if let Some(value) = message_down_cast_ref::(message) { + } else if let Some(value) = message.downcast_ref::() { self.print_wrapper(value) - } else if let Some(value) = message_down_cast_ref::(message) { + } else if let Some(value) = message.downcast_ref::() { self.print_printable(value) - } else if let Some(value) = message_down_cast_ref::(message) { + } else if let Some(value) = message.downcast_ref::() { self.print_printable(value) } else { self.print_regular_message(message) diff --git a/protobuf/src/reflect/accessor/map.rs b/protobuf/src/reflect/accessor/map.rs index cb4240f83..f35d345a9 100644 --- a/protobuf/src/reflect/accessor/map.rs +++ b/protobuf/src/reflect/accessor/map.rs @@ -1,10 +1,8 @@ use std::collections::HashMap; use std::hash::Hash; -use core::message_down_cast_ref; use Message; -use core::message_down_cast_mut; use reflect::accessor::AccessorKind; use reflect::accessor::FieldAccessor; use reflect::map::ReflectMapMut; @@ -47,7 +45,7 @@ where ::Value: Eq + Hash, { fn get_reflect<'a>(&self, m: &'a Message) -> ReflectMapRef<'a> { - let m = message_down_cast_ref(m).unwrap(); + let m = m.downcast_ref().unwrap(); let map = (self.get_field)(m); ReflectMapRef { map, @@ -57,7 +55,7 @@ where } fn mut_reflect<'a>(&self, m: &'a mut Message) -> ReflectMapMut<'a> { - let m = message_down_cast_mut(m).unwrap(); + let m = m.downcast_mut().unwrap(); let map = (self.mut_field)(m); ReflectMapMut { map, diff --git a/protobuf/src/reflect/accessor/repeated.rs b/protobuf/src/reflect/accessor/repeated.rs index db81d8e53..19d7c5d5a 100644 --- a/protobuf/src/reflect/accessor/repeated.rs +++ b/protobuf/src/reflect/accessor/repeated.rs @@ -3,8 +3,6 @@ use std::marker; use Message; use RepeatedField; -use core::message_down_cast_ref; -use core::message_down_cast_mut; use reflect::accessor::AccessorKind; use reflect::accessor::FieldAccessor; use reflect::repeated::ReflectRepeated; @@ -82,7 +80,7 @@ where V: ProtobufType, { fn get_reflect<'a>(&self, m: &'a Message) -> ReflectRepeatedRef<'a> { - let m = message_down_cast_ref(m).unwrap(); + let m = m.downcast_ref().unwrap(); let repeated = self.fns.get_field(m); ReflectRepeatedRef { repeated, @@ -91,7 +89,7 @@ where } fn mut_reflect<'a>(&self, m: &'a mut Message) -> ReflectRepeatedMut<'a> { - let m = message_down_cast_mut(m).unwrap(); + let m = m.downcast_mut().unwrap(); let repeated = self.fns.mut_field(m); ReflectRepeatedMut { repeated, diff --git a/protobuf/src/reflect/accessor/singular.rs b/protobuf/src/reflect/accessor/singular.rs index 802608061..bedb4b2de 100644 --- a/protobuf/src/reflect/accessor/singular.rs +++ b/protobuf/src/reflect/accessor/singular.rs @@ -1,5 +1,3 @@ -use core::message_down_cast_ref; -use core::message_down_cast_mut; use reflect::accessor::AccessorKind; use reflect::accessor::FieldAccessor; use reflect::runtime_types::RuntimeType; @@ -73,18 +71,18 @@ where } fn get_reflect<'a>(&self, m: &'a Message) -> Option> { - let m = message_down_cast_ref(m).unwrap(); + let m = m.downcast_ref().unwrap(); self.get_option_impl.get_reflect_impl(m) } fn get_singular_field_or_default<'a>(&self, m: &'a Message) -> ReflectValueRef<'a> { - let m = message_down_cast_ref(m).unwrap(); + let m = m.downcast_ref().unwrap(); self.get_or_default_impl .get_singular_field_or_default_impl(m) } fn set_singular_field(&self, m: &mut Message, value: ReflectValueBox) { - let m = message_down_cast_mut(m).unwrap(); + let m = m.downcast_mut().unwrap(); self.set_impl.set_singular_field(m, value) } } @@ -156,7 +154,6 @@ where O: OptionLike + Sync + Send + 'static, { fn get_reflect_impl<'a>(&self, m: &'a M) -> Option> { - let m = message_down_cast_ref(m).unwrap(); (self.get_field)(m).as_option_ref().map(V::as_ref) } }