Skip to content

Commit

Permalink
add recursion limit
Browse files Browse the repository at this point in the history
  • Loading branch information
BusyJay authored and stepancheg committed Oct 31, 2017
1 parent df87ecf commit 51f50fb
Show file tree
Hide file tree
Showing 5 changed files with 80 additions and 6 deletions.
30 changes: 30 additions & 0 deletions protobuf-test/src/common/v2/test_basic.rs
Original file line number Diff line number Diff line change
Expand Up @@ -36,6 +36,36 @@ fn test4() {
test_serialize_deserialize("22 06 03 8E 02 9E A7 05", &test4);
}

#[test]
fn test_recursion_limit() {
let mut test = TestRecursion::new();
for _ in 0..10 {
let mut t = TestRecursion::new();
t.mut_children().push(test);
test = t;
}

let bytes = test.write_to_bytes().unwrap();
let cases = vec![
(None, false),
(Some(9), true),
(Some(10), false),
];

for (limit, has_err) in cases {
let mut is = CodedInputStream::from_bytes(&bytes);
if let Some(limit) = limit {
is.set_recursion_limit(limit);
}
let mut t = TestRecursion::new();
let res = t.merge_from(&mut is);
assert_eq!(res.is_err(), has_err, "limit: {:?}", limit);
if !has_err {
assert_eq!(t, test, "limit: {:?}", limit);
}
}
}

#[test]
fn test_read_unpacked_expect_packed() {
let mut test_packed_unpacked = TestPackedUnpacked::new();
Expand Down
4 changes: 4 additions & 0 deletions protobuf-test/src/common/v2/test_basic_pb.proto
Original file line number Diff line number Diff line change
Expand Up @@ -14,6 +14,10 @@ message Test3 {
required Test1 c = 3;
}

message TestRecursion {
repeated TestRecursion children = 1;
}

message Test4 {
repeated int32 d = 4 [packed=true];
}
Expand Down
2 changes: 2 additions & 0 deletions protobuf/src/error.rs
Original file line number Diff line number Diff line change
Expand Up @@ -18,6 +18,7 @@ pub enum WireError {
IncorrectVarint,
Utf8Error,
InvalidEnumValue(i32),
OverRecursionLimit,
Other,
}

Expand Down Expand Up @@ -55,6 +56,7 @@ impl Error for ProtobufError {
WireError::IncorrectVarint => "incorrect varint",
WireError::IncompleteMap => "incomplete map",
WireError::UnexpectedEof => "unexpected EOF",
WireError::OverRecursionLimit => "over recursion limit",
WireError::Other => "other error",
}
}
Expand Down
10 changes: 8 additions & 2 deletions protobuf/src/rt.rs
Original file line number Diff line number Diff line change
Expand Up @@ -638,8 +638,11 @@ pub fn read_repeated_message_into<M : Message + Default>(
) -> ProtobufResult<()> {
match wire_type {
WireTypeLengthDelimited => {
is.incr_recursion()?;
let tmp = target.push_default();
is.merge_message(tmp)
let res = is.merge_message(tmp);
is.decr_recursion();
res
}
_ => Err(unexpected_wire_type(wire_type)),
}
Expand All @@ -652,8 +655,11 @@ pub fn read_singular_message_into<M : Message + Default>(
) -> ProtobufResult<()> {
match wire_type {
WireTypeLengthDelimited => {
is.incr_recursion()?;
let tmp = target.set_default();
is.merge_message(tmp)
let res = is.merge_message(tmp);
is.decr_recursion();
res
}
_ => Err(unexpected_wire_type(wire_type)),
}
Expand Down
40 changes: 36 additions & 4 deletions protobuf/src/stream.rs
Original file line number Diff line number Diff line change
Expand Up @@ -31,6 +31,9 @@ use buf_read_iter::BufReadIter;
// `CodedOutputStream` wraps `BufWriter`, it often skips double buffering.
const OUTPUT_STREAM_BUFFER_SIZE: usize = 8 * 1024;

// Default recursion level limit. 100 is the default value of C++'s implementation.
const DEFAULT_RECURSION_LIMIT: u32 = 100;


pub mod wire_format {
// TODO: temporary
Expand Down Expand Up @@ -121,24 +124,53 @@ pub mod wire_format {

pub struct CodedInputStream<'a> {
source: BufReadIter<'a>,
recursion_level: u32,
recursion_limit: u32,
}

impl<'a> CodedInputStream<'a> {
pub fn new(read: &'a mut Read) -> CodedInputStream<'a> {
CodedInputStream { source: BufReadIter::from_read(read) }
CodedInputStream::from_buf_read_iter(BufReadIter::from_read(read))
}

pub fn from_buffered_reader(buf_read: &'a mut BufRead) -> CodedInputStream<'a> {
CodedInputStream { source: BufReadIter::from_buf_read(buf_read) }
CodedInputStream::from_buf_read_iter(BufReadIter::from_buf_read(buf_read))
}

pub fn from_bytes(bytes: &'a [u8]) -> CodedInputStream<'a> {
CodedInputStream { source: BufReadIter::from_byte_slice(bytes) }
CodedInputStream::from_buf_read_iter(BufReadIter::from_byte_slice(bytes))
}

#[cfg(feature = "bytes")]
pub fn from_carllerche_bytes(bytes: &'a Bytes) -> CodedInputStream<'a> {
CodedInputStream { source: BufReadIter::from_bytes(bytes) }
CodedInputStream::from_buf_read_iter(BufReadIter::from_bytes(bytes))
}

fn from_buf_read_iter(source: BufReadIter<'a>) -> CodedInputStream<'a> {
CodedInputStream {
source: source,
recursion_level: 0,
recursion_limit: DEFAULT_RECURSION_LIMIT,
}
}

/// Set the recursion limit.
pub fn set_recursion_limit(&mut self, limit: u32) {
self.recursion_limit = limit;
}

#[inline]
pub(crate) fn incr_recursion(&mut self) -> ProtobufResult<()> {
if self.recursion_level >= self.recursion_limit {
return Err(ProtobufError::WireError(WireError::OverRecursionLimit));
}
self.recursion_level += 1;
Ok(())
}

#[inline]
pub(crate) fn decr_recursion(&mut self) {
self.recursion_level -= 1;
}

pub fn pos(&self) -> u64 {
Expand Down

0 comments on commit 51f50fb

Please sign in to comment.