Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Add within support to allow iterating over IpNetworks (cidrs) #50

Closed
wants to merge 18 commits into from
Closed
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
1 change: 1 addition & 0 deletions Cargo.toml
Original file line number Diff line number Diff line change
Expand Up @@ -23,6 +23,7 @@ name ="maxminddb"
path = "src/maxminddb/lib.rs"

[dependencies]
ipnetwork = "0.18.0"
log = "0.4"
serde = { version = "1.0", features = ["derive"] }
memchr = "2.4"
Expand Down
44 changes: 44 additions & 0 deletions examples/within.rs
Original file line number Diff line number Diff line change
@@ -0,0 +1,44 @@
use ipnetwork::IpNetwork;
use maxminddb::{geoip2, Within};

fn main() -> Result<(), String> {
let mut args = std::env::args().skip(1);
let reader = maxminddb::Reader::open_readfile(
args.next()
.ok_or("First argument must be the path to the IP database")?,
)
.unwrap();
let cidr: String = args
.next()
.ok_or("Second argument must be the IP address and mask in CIDR notation, e.g. 0.0.0.0/0 or ::/0")?
.parse()
.unwrap();
let ip_net = if cidr.contains(":") {
IpNetwork::V6(cidr.parse().unwrap())
} else {
IpNetwork::V4(cidr.parse().unwrap())
};

let mut n = 0;
let mut iter: Within<geoip2::City, _> = reader.within(ip_net).map_err(|e| e.to_string())?;
while let Some(next) = iter.next() {
let item = next.map_err(|e| e.to_string())?;
let continent = item.info.continent.and_then(|c| c.code).unwrap_or("");
let country = item.info.country.and_then(|c| c.iso_code).unwrap_or("");
let city = match item.info.city.and_then(|c| c.names) {
Some(names) => names.get("en").unwrap_or(&""),
None => "",
};
if !city.is_empty() {
println!("{} {}-{}-{}", item.ip_net, continent, country, city);
} else if !country.is_empty() {
println!("{} {}-{}", item.ip_net, continent, country);
} else if !continent.is_empty() {
println!("{} {}", item.ip_net, continent);
}
n += 1;
}
eprintln!("processed {} items", n);

Ok(())
}
205 changes: 204 additions & 1 deletion src/maxminddb/lib.rs
Original file line number Diff line number Diff line change
Expand Up @@ -3,9 +3,11 @@
use std::collections::BTreeMap;
use std::fmt::{self, Display, Formatter};
use std::io;
use std::net::IpAddr;
use std::marker::PhantomData;
use std::net::{IpAddr, Ipv4Addr, Ipv6Addr};
use std::path::Path;

use ipnetwork::IpNetwork;
use serde::{de, Deserialize};

#[cfg(feature = "mmap")]
Expand All @@ -22,6 +24,7 @@ pub enum MaxMindDBError {
IoError(String),
MapError(String),
DecodingError(String),
InvalidNetworkError(String),
}

impl From<io::Error> for MaxMindDBError {
Expand All @@ -43,6 +46,9 @@ impl Display for MaxMindDBError {
MaxMindDBError::IoError(msg) => write!(fmt, "IoError: {}", msg)?,
MaxMindDBError::MapError(msg) => write!(fmt, "MapError: {}", msg)?,
MaxMindDBError::DecodingError(msg) => write!(fmt, "DecodingError: {}", msg)?,
MaxMindDBError::InvalidNetworkError(msg) => {
write!(fmt, "InvalidNetworkError: {}", msg)?
}
}
Ok(())
}
Expand Down Expand Up @@ -70,7 +76,90 @@ pub struct Metadata {
pub record_size: u16,
}

#[derive(Debug)]
struct WithinNode {
node: usize,
ip_bytes: Vec<u8>,
prefix_len: usize,
}

#[derive(Debug)]
pub struct Within<'de, T: Deserialize<'de>, S: AsRef<[u8]>> {
reader: &'de Reader<S>,
node_count: usize,
stack: Vec<WithinNode>,
phantom: PhantomData<&'de T>,
}

#[derive(Debug)]
pub struct WithinItem<T> {
pub ip_net: IpNetwork,
pub info: T,
}

impl<'de, T: Deserialize<'de>, S: AsRef<[u8]>> Iterator for Within<'de, T, S> {
type Item = Result<WithinItem<T>, MaxMindDBError>;
Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Fairly new to rust (in case it's not obvious) and I'm not finding a clean way to go about returning errors that happen during the walking of nodes. My searching for idiomatic patterns didn't turn anything up, but it's a tough thing to search for.

This results in some pretty ugly code below to map errors in the helper functions into Some(Err(e)). That seems bad enough, but currently have have things in match Ok/Err conditionals that make it tough to read. Things like .map_err won't let me wrap with Some so I don't see a cleaner way. Seems like either a shortcoming of iterators or I'm just going about this all wrong 😁

Input/thoughts/suggestions welcome.

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Looks like there may be a crate, https://docs.rs/fallible-iterator/0.2.0/fallible_iterator/ that calls out the situation.

Seems like using it might be the cleaner way to go, but will defer to more experienced opinions.

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I have a FallibleIterator based solution and it cleans up pretty much all of the ugliness in error handling here. It's not perfect since you can't use for x in y, and have to go with while let Some(x) = y.next(), but that's a limitation of rust and seems like a reasonable tradeoff for cleaner error handling.

Changes for that can be seen with https:/ross/maxminddb-rust/compare/within-support...ross:within-support-fallibleiterator?expand=1. Thoughts welcome.

Copy link
Owner

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Although it does clean up the code somewhat, I think I'd prefer not having the dependency, especially given that the fallible iterator crate hasn't seen an update in two years and doesn't seem that widely used.

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I think it hasn't seen any updates b/c it does what it needs to do. Similar to how I wrote requests-futures 8 years ago and haven't really touched it meaningfully in most of that time.

As for not wanting the dep, I don't have a response to that. The code to implement the non-fallible iterator is ugly as hell, but using it isn't THAT bad. Kind of seems like a shortcoming of Rust itself that iterators can't fail.

I'll stick this back on my TODO list to pick back up and see if I can get it over the line.


fn next(&mut self) -> Option<Self::Item> {
while !self.stack.is_empty() {
let current = self.stack.pop().unwrap();
let bit_count = current.ip_bytes.len() * 8;

if current.node > self.node_count {
// This is a data node, emit it and we're done (until the following next call)
let ip_net =
match bytes_and_prefix_to_net(&current.ip_bytes, current.prefix_len as u8) {
Ok(ip_net) => ip_net,
Err(e) => return Some(Err(e)),
};
// TODO: should this block become a helper method on reader?
let rec = match self.reader.resolve_data_pointer(current.node) {
Ok(rec) => rec,
Err(e) => return Some(Err(e)),
};
let mut decoder = decoder::Decoder::new(
&self.reader.buf.as_ref()[self.reader.pointer_base..],
rec,
);
return match T::deserialize(&mut decoder) {
Ok(info) => Some(Ok(WithinItem { ip_net, info })),
Err(e) => Some(Err(e)),
};
} else if current.node == self.node_count {
// Dead end, nothing to do
} else {
// In order traversal of our children
// right/1-bit
let mut right_ip_bytes = current.ip_bytes.clone();
right_ip_bytes[current.prefix_len >> 3] |=
1 << ((bit_count - current.prefix_len - 1) % 8);
let node = match self.reader.read_node(current.node, 1) {
Ok(node) => node,
Err(e) => return Some(Err(e)),
};
self.stack.push(WithinNode {
node,
ip_bytes: right_ip_bytes,
prefix_len: current.prefix_len + 1,
});
// left/0-bit
let node = match self.reader.read_node(current.node, 0) {
Ok(node) => node,
Err(e) => return Some(Err(e)),
};
self.stack.push(WithinNode {
node,
ip_bytes: current.ip_bytes.clone(),
prefix_len: current.prefix_len + 1,
});
}
}
None
}
}

/// A reader for the MaxMind DB format. The lifetime `'data` is tied to the lifetime of the underlying buffer holding the contents of the database file.
#[derive(Debug)]
pub struct Reader<S: AsRef<[u8]>> {
buf: S,
pub metadata: Metadata,
Expand Down Expand Up @@ -172,6 +261,72 @@ impl<'de, S: AsRef<[u8]>> Reader<S> {
T::deserialize(&mut decoder)
}

/// Iterate over blocks of IP networks in the opened MaxMind DB
///
/// Example:
///
/// ```
/// use ipnetwork::IpNetwork;
/// use maxminddb::{geoip2, Within};
///
/// let reader = maxminddb::Reader::open_readfile("test-data/test-data/GeoIP2-City-Test.mmdb").unwrap();
///
/// let ip_net = IpNetwork::V6("::/0".parse().unwrap());
/// let mut iter: Within<geoip2::City, _> = reader.within(ip_net).unwrap();
/// while let Some(next) = iter.next() {
/// let item = next.unwrap();
/// println!("ip_net={}, city={:?}", item.ip_net, item.info);
/// }
/// ```
pub fn within<T>(&'de self, cidr: IpNetwork) -> Result<Within<T, S>, MaxMindDBError>
where
T: Deserialize<'de>,
{
let ip_address = cidr.network();
let prefix_len = cidr.prefix() as usize;
let ip_bytes = ip_to_bytes(ip_address);
let bit_count = ip_bytes.len() * 8;

let mut node = self.start_node(bit_count);
let node_count = self.metadata.node_count as usize;

let mut stack: Vec<WithinNode> = Vec::with_capacity(bit_count - prefix_len);

// Traverse down the tree to the level that matches the cidr mark
let mut i = 0_usize;
while i < prefix_len {
let bit = 1 & (ip_bytes[i >> 3] >> 7 - (i % 8)) as usize;
node = self.read_node(node, bit)?;
if node >= node_count {
// We've hit a dead end before we exhausted our prefix
break;
}

i += 1;
}

if node < node_count {
// Ok, now anything that's below node in the tree is "within", start with the node we
// traversed to as our to be processed stack.
stack.push(WithinNode {
node,
ip_bytes,
prefix_len,
});
}
// else the stack will be empty and we'll be returning an iterator that visits nothing,
// which makes sense.

let within: Within<T, S> = Within {
reader: self,
node_count,
stack,
phantom: PhantomData,
};

Ok(within)
}

fn find_address_in_tree(&self, ip_address: &[u8]) -> Result<usize, MaxMindDBError> {
let bit_count = ip_address.len() * 8;
let mut node = self.start_node(bit_count);
Expand Down Expand Up @@ -284,6 +439,54 @@ fn ip_to_bytes(address: IpAddr) -> Vec<u8> {
}
}

fn bytes_and_prefix_to_net(bytes: &Vec<u8>, prefix: u8) -> Result<IpNetwork, MaxMindDBError> {
let (ip, pre) = match bytes.len() {
4 => (
IpAddr::V4(Ipv4Addr::new(bytes[0], bytes[1], bytes[2], bytes[3])),
prefix,
),
16 => {
if bytes[0] == 0
&& bytes[1] == 0
&& bytes[2] == 0
&& bytes[3] == 0
&& bytes[4] == 0
&& bytes[5] == 0
&& bytes[6] == 0
&& bytes[7] == 0
&& bytes[8] == 0
&& bytes[9] == 0
&& bytes[10] == 0
&& bytes[11] == 0
{
// It's actually v4, but in v6 form, convert would be nice if ipnetwork had this
// logic.
(
IpAddr::V4(Ipv4Addr::new(bytes[12], bytes[13], bytes[14], bytes[15])),
prefix - 96,
)
} else {
let a = (bytes[0] as u16) << 8 | bytes[1] as u16;
let b = (bytes[2] as u16) << 8 | bytes[3] as u16;
let c = (bytes[4] as u16) << 8 | bytes[5] as u16;
let d = (bytes[6] as u16) << 8 | bytes[7] as u16;
let e = (bytes[8] as u16) << 8 | bytes[9] as u16;
let f = (bytes[10] as u16) << 8 | bytes[11] as u16;
let g = (bytes[12] as u16) << 8 | bytes[13] as u16;
let h = (bytes[14] as u16) << 8 | bytes[15] as u16;
(IpAddr::V6(Ipv6Addr::new(a, b, c, d, e, f, g, h)), prefix)
}
}
// This should never happen
_ => {
return Err(MaxMindDBError::InvalidNetworkError(
"invalid address".to_owned(),
))
}
};
IpNetwork::new(ip, pre).map_err(|e| MaxMindDBError::InvalidNetworkError(e.to_string()))
}

fn find_metadata_start(buf: &[u8]) -> Result<usize, MaxMindDBError> {
const METADATA_START_MARKER: &[u8] = b"\xab\xcd\xefMaxMind.com";

Expand Down
71 changes: 71 additions & 0 deletions src/maxminddb/reader_test.rs
Original file line number Diff line number Diff line change
Expand Up @@ -335,6 +335,77 @@ fn test_lookup_asn() {
assert_eq!(asn.autonomous_system_organization, Some("Telstra Pty Ltd"));
}

#[test]
fn test_within_city() {
use super::geoip2::City;
use super::Within;
use ipnetwork::IpNetwork;

let _ = env_logger::try_init();

let filename = "test-data/test-data/GeoIP2-City-Test.mmdb";

let reader = Reader::open_readfile(filename).unwrap();

let ip_net = IpNetwork::V6("::/0".parse().unwrap());

let mut iter: Within<City, _> = reader.within(ip_net).unwrap();

// Make sure the first is what we expect it to be
let item = iter.next().unwrap().unwrap();
assert_eq!(
item.ip_net,
IpNetwork::V4("2.125.160.216/29".parse().unwrap())
);
assert_eq!(item.info.continent.unwrap().code, Some("EU"));
assert_eq!(item.info.country.unwrap().iso_code, Some("GB"));

let mut n = 1;
while let Some(_) = iter.next() {
n += 1;
}

// Make sure we had the expected number
assert_eq!(n, 273);

// A second run through this time a specific network
let specific = IpNetwork::V4("81.2.69.0/24".parse().unwrap());
let mut iter: Within<City, _> = reader.within(specific).unwrap();
// Make sure we have the expected blocks/info
let mut expected = vec![
// Note: reversed so we can use pop
IpNetwork::V4("81.2.69.192/28".parse().unwrap()),
IpNetwork::V4("81.2.69.160/27".parse().unwrap()),
IpNetwork::V4("81.2.69.144/28".parse().unwrap()),
IpNetwork::V4("81.2.69.142/31".parse().unwrap()),
];
while expected.len() > 0 {
let e = expected.pop().unwrap();
let item = iter.next().unwrap().unwrap();
assert_eq!(item.ip_net, e);
}
}

#[test]
fn test_within_broken_database() {
use super::geoip2::City;
use ipnetwork::IpNetwork;

let r = Reader::open_readfile("test-data/test-data/GeoIP2-City-Test-Broken-Double-Format.mmdb")
.ok()
.unwrap();

let ip_net = IpNetwork::V6("::/0".parse().unwrap());
let mut iter = r.within::<City>(ip_net).unwrap();
match iter.next().unwrap() {
Err(e) => assert_eq!(
e,
MaxMindDBError::InvalidDatabaseError("double of size 7".to_string())
),
Ok(_) => panic!("Error expected"),
};
}

fn check_metadata<T: AsRef<[u8]>>(reader: &Reader<T>, ip_version: usize, record_size: usize) {
let metadata = &reader.metadata;

Expand Down