Skip to content

Commit

Permalink
Merge to the latest
Browse files Browse the repository at this point in the history
  • Loading branch information
AgeManning committed Jan 15, 2024
2 parents 8fd5434 + 6187597 commit 2ecc61f
Show file tree
Hide file tree
Showing 7 changed files with 420 additions and 232 deletions.
184 changes: 89 additions & 95 deletions src/handler/mod.rs

Large diffs are not rendered by default.

21 changes: 8 additions & 13 deletions src/handler/nat_hole_punch/mod.rs
Original file line number Diff line number Diff line change
@@ -1,9 +1,10 @@
use std::net::SocketAddr;

use enr::NodeId;

use crate::{
node_info::NodeAddress, packet::MessageNonce, rpc::Notification, Enr, ProtocolIdentity,
node_info::NodeAddress,
packet::MessageNonce,
rpc::{RelayInitNotification, RelayMsgNotification},
Enr, ProtocolIdentity,
};

mod error;
Expand All @@ -27,26 +28,20 @@ pub trait HolePunchNat {

/// A RelayInit notification is received over discv5 indicating this node is the relay. Should
/// trigger sending a RelayMsg to the target.
async fn on_relay_init(
&mut self,
initr: Enr,
tgt: NodeId,
timed_out_nonce: MessageNonce,
) -> Result<(), Error>;
async fn on_relay_init(&mut self, relay_init: RelayInitNotification) -> Result<(), Error>;

/// A RelayMsg notification is received over discv5 indicating this node is the target. Should
/// trigger a WHOAREYOU to be sent to the initiator using the `nonce` in the RelayMsg.
async fn on_relay_msg(
async fn on_relay_msg<P: ProtocolIdentity>(
&mut self,
initr: Enr,
timed_out_nonce: MessageNonce,
relay_msg: RelayMsgNotification,
) -> Result<(), Error>;

/// Send a RELAYMSG notification.
async fn send_relay_msg_notif<P: ProtocolIdentity>(
&mut self,
tgt_enr: Enr,
relay_msg_notif: Notification,
relay_msg_notif: RelayMsgNotification,
) -> Result<(), Error>;

/// A hole punched for a peer closes. Should trigger an empty packet to be sent to the
Expand Down
148 changes: 129 additions & 19 deletions src/handler/tests.rs
Original file line number Diff line number Diff line change
Expand Up @@ -3,7 +3,7 @@
use super::*;
use crate::{
handler::session::build_dummy_session,
packet::{DefaultProtocolId, PacketHeader, MAX_PACKET_SIZE},
packet::{DefaultProtocolId, PacketHeader, MAX_PACKET_SIZE, MESSAGE_NONCE_LENGTH},
return_if_ipv6_is_not_supported,
rpc::{Request, Response},
Discv5ConfigBuilder, IpMode,
Expand Down Expand Up @@ -94,7 +94,10 @@ async fn build_handler_with_listen_config<P: ProtocolIdentity>(
active_requests: ActiveRequests::new(config.request_timeout),
pending_requests: HashMap::new(),
filter_expected_responses,
sessions: LruTimeCache::new(config.session_timeout,Some(config.session_cache_capacity)),
sessions: LruTimeCache::new(
config.session_timeout,
Some(config.session_cache_capacity),
),
one_time_sessions: LruTimeCache::new(
Duration::from_secs(ONE_TIME_SESSION_TIMEOUT),
Some(ONE_TIME_SESSION_CACHE_CAPACITY),
Expand Down Expand Up @@ -486,7 +489,7 @@ async fn remove_one_time_session() {
}

#[tokio::test(flavor = "multi_thread")]
async fn relay() {
async fn nat_hole_punch_relay() {
init();

// Relay
Expand All @@ -497,23 +500,23 @@ async fn relay() {
let relay_node_id = handler.enr.read().node_id();

// Initiator
let initr_enr = {
let inr_enr = {
let key = CombinedKey::generate_secp256k1();
EnrBuilder::new("v4")
.ip4(Ipv4Addr::LOCALHOST)
.udp4(9011)
.build(&key)
.unwrap()
};
let initr_addr = initr_enr.udp4_socket().unwrap().into();
let initr_node_id = initr_enr.node_id();
let inr_addr = inr_enr.udp4_socket().unwrap().into();
let inr_node_id = inr_enr.node_id();

let initr_node_address = NodeAddress::new(initr_addr, initr_enr.node_id());
let initr_node_address = NodeAddress::new(inr_addr, inr_enr.node_id());
handler
.sessions
.insert(initr_node_address, build_dummy_session());

let initr_socket = UdpSocket::bind(initr_addr)
let inr_socket = UdpSocket::bind(inr_addr)
.await
.expect("should bind to initiator socket");

Expand Down Expand Up @@ -548,25 +551,25 @@ async fn relay() {
let mock_service_handle = tokio::spawn(async move {
let service_msg = rx.recv().await.expect("should receive service message");
match service_msg {
HandlerOut::FindHolePunchEnr(_tgt_node_id, relay_msg_notif) => tx
.send(HandlerIn::HolePunchEnr(tgt_enr_clone, relay_msg_notif))
HandlerOut::FindHolePunchEnr(relay_init) => tx
.send(HandlerIn::HolePunchEnr(tgt_enr_clone, relay_init))
.expect("should send message to handler"),
_ => panic!("service message should be 'find hole punch enr'"),
}
});

// Initiator handle
let relay_init_notif =
Notification::RelayInit(initr_enr.clone(), tgt_node_id, MessageNonce::default());
RelayInitNotification::new(inr_enr.clone(), tgt_node_id, MessageNonce::default());

let initr_handle = tokio::spawn(async move {
let inr_handle = tokio::spawn(async move {
let mut session = build_dummy_session();
let packet = session
.encrypt_session_message::<DefaultProtocolId>(initr_node_id, &relay_init_notif.encode())
.encrypt_session_message::<DefaultProtocolId>(inr_node_id, &relay_init_notif.encode())
.expect("should encrypt notification");
let encoded_packet = packet.encode::<DefaultProtocolId>(&relay_node_id);

initr_socket
inr_socket
.send_to(&encoded_packet, relay_addr)
.await
.expect("should relay init notification to relay")
Expand All @@ -587,10 +590,10 @@ async fn relay() {
});

// Join all handles
let (initr_res, relay_res, tgt_res, mock_service_res) =
tokio::join!(initr_handle, relay_handle, tgt_handle, mock_service_handle);
let (inr_res, relay_res, tgt_res, mock_service_res) =
tokio::join!(inr_handle, relay_handle, tgt_handle, mock_service_handle);

initr_res.unwrap();
inr_res.unwrap();
relay_res.unwrap();
mock_service_res.unwrap();

Expand Down Expand Up @@ -620,9 +623,116 @@ async fn relay() {
.decrypt_message(message_nonce, &message, &aad)
.expect("should decrypt message");
match Message::decode(&decrypted_message).expect("should decode message") {
Message::Notification(Notification::RelayMsg(enr, _nonce)) => {
assert_eq!(initr_enr, enr)
Message::RelayMsgNotification(relay_msg) => {
let (enr, _) = relay_msg.into();
assert_eq!(inr_enr, enr)
}
_ => panic!("message should decode to a relay msg notification"),
}
}

#[tokio::test(flavor = "multi_thread")]
async fn nat_hole_punch_target() {
init();

// Target
let listen_config = ListenConfig::default().with_ipv4(Ipv4Addr::LOCALHOST, 9902);
let (mut handler, mock_service) =
build_handler_with_listen_config::<DefaultProtocolId>(listen_config).await;
let tgt_addr = handler.enr.read().udp4_socket().unwrap().into();
let tgt_node_id = handler.enr.read().node_id();
handler.nat_utils.is_behind_nat = Some(true);

// Relay
let relay_enr = {
let key = CombinedKey::generate_secp256k1();
EnrBuilder::new("v4")
.ip4(Ipv4Addr::LOCALHOST)
.udp4(9022)
.build(&key)
.unwrap()
};
let relay_addr = relay_enr.udp4_socket().unwrap().into();
let relay_node_id = relay_enr.node_id();

let relay_node_address = NodeAddress::new(relay_addr, relay_node_id);
handler
.sessions
.insert(relay_node_address, build_dummy_session());

let relay_socket = UdpSocket::bind(relay_addr)
.await
.expect("should bind to target socket");

// Initiator
let inr_enr = {
let key = CombinedKey::generate_secp256k1();
EnrBuilder::new("v4")
.ip4(Ipv4Addr::LOCALHOST)
.udp4(9021)
.build(&key)
.unwrap()
};
let inr_addr = inr_enr.udp4_socket().unwrap();
let inr_node_id = inr_enr.node_id();
let inr_nonce: MessageNonce = [1; MESSAGE_NONCE_LENGTH];

let inr_socket = UdpSocket::bind(inr_addr)
.await
.expect("should bind to initiator socket");

// Target handle
let tgt_handle = tokio::spawn(async move { handler.start::<DefaultProtocolId>().await });

// Relay handle
let relay_msg_notif = RelayMsgNotification::new(inr_enr.clone(), inr_nonce);

let relay_handle = tokio::spawn(async move {
let mut session = build_dummy_session();
let packet = session
.encrypt_session_message::<DefaultProtocolId>(relay_node_id, &relay_msg_notif.encode())
.expect("should encrypt notification");
let encoded_packet = packet.encode::<DefaultProtocolId>(&tgt_node_id);

relay_socket
.send_to(&encoded_packet, tgt_addr)
.await
.expect("should relay init notification to relay")
});

// Initiator handle
let target_exit = mock_service.exit_tx;
let inr_handle = tokio::spawn(async move {
let mut buffer = [0; MAX_PACKET_SIZE];
let res = inr_socket
.recv_from(&mut buffer)
.await
.expect("should read bytes from socket");

drop(target_exit);

(res, buffer)
});

// Join all handles
let (tgt_res, relay_res, inr_res) = tokio::join!(tgt_handle, relay_handle, inr_handle);

tgt_res.unwrap();
relay_res.unwrap();

let ((length, src), buffer) = inr_res.unwrap();

assert_eq!(src, tgt_addr);

let (packet, _aad) = Packet::decode::<DefaultProtocolId>(&inr_node_id, &buffer[..length])
.expect("should decode packet");
let Packet { header, .. } = packet;
let PacketHeader {
kind,
message_nonce,
..
} = header;

assert!(kind.is_whoareyou());
assert_eq!(message_nonce, inr_nonce)
}
4 changes: 2 additions & 2 deletions src/kbucket/entry.rs
Original file line number Diff line number Diff line change
Expand Up @@ -174,8 +174,8 @@ where
PendingEntry(EntryRef { bucket, key })
}

/// Returns the value associated with the key.
pub fn value(&mut self) -> &mut TVal {
/// Returns mutable access value associated with the key.
pub fn value_mut(&mut self) -> &mut TVal {
self.0
.bucket
.pending_mut()
Expand Down
42 changes: 26 additions & 16 deletions src/rpc.rs
Original file line number Diff line number Diff line change
Expand Up @@ -6,7 +6,7 @@ mod notification;
mod request;
mod response;

pub use notification::Notification;
pub use notification::{RelayInitNotification, RelayMsgNotification};
pub use request::{Request, RequestBody, RequestId};
pub use response::{Response, ResponseBody};

Expand Down Expand Up @@ -60,12 +60,19 @@ pub enum Message {
/// A request, which contains its [`RequestId`].
#[display(fmt = "{_0}")]
Request(Request),

/// A Response, which contains the [`RequestId`] of its associated request.
#[display(fmt = "{_0}")]
Response(Response),
/// A unicast notification.

/// Unicast notifications.
///
/// A [`RelayInitNotification`].
#[display(fmt = "{_0}")]
Notification(Notification),
RelayInitNotification(RelayInitNotification),
/// A [`RelayMsgNotification`].
#[display(fmt = "{_0}")]
RelayMsgNotification(RelayMsgNotification),
}

#[allow(dead_code)]
Expand All @@ -74,7 +81,8 @@ impl Message {
match self {
Self::Request(request) => request.encode(),
Self::Response(response) => response.encode(),
Self::Notification(notif) => notif.encode(),
Self::RelayInitNotification(notif) => notif.encode(),
Self::RelayMsgNotification(notif) => notif.encode(),
}
}

Expand All @@ -93,17 +101,17 @@ impl Message {
MessageType::Pong | MessageType::Nodes | MessageType::TalkResp => {
Ok(Response::decode(msg_type, &rlp)?.into())
}
MessageType::RelayInit | MessageType::RelayMsg => {
Ok(Notification::decode(msg_type, &rlp)?.into())
}
MessageType::RelayInit => Ok(RelayInitNotification::decode(msg_type, &rlp)?.into()),
MessageType::RelayMsg => Ok(RelayMsgNotification::decode(msg_type, &rlp)?.into()),
}
}

pub fn msg_type(&self) -> String {
match self {
Self::Notification(n) => format!("notification type {}", n.msg_type()),
Self::Request(r) => format!("request type {}", r.msg_type()),
Self::Response(r) => format!("response type {}", r.msg_type()),
Self::RelayInitNotification(n) => format!("notification type {}", n.msg_type()),
Self::RelayMsgNotification(n) => format!("notification type {}", n.msg_type()),
}
}
}
Expand Down Expand Up @@ -413,12 +421,13 @@ mod tests {
let mut nonce = [0u8; MESSAGE_NONCE_LENGTH];
nonce[MESSAGE_NONCE_LENGTH - nonce_bytes.len()..].copy_from_slice(&nonce_bytes);

let notif = Message::Notification(Notification::RelayInit(inr_enr, tgt_node_id, nonce));
let notif = RelayInitNotification::new(inr_enr, tgt_node_id, nonce);
let msg = Message::RelayInitNotification(notif);

let encoded_notif = notif.clone().encode();
let decoded_notif = Message::decode(&encoded_notif).expect("Should decode");
let encoded_msg = msg.clone().encode();
let decoded_msg = Message::decode(&encoded_msg).expect("Should decode");

assert_eq!(notif, decoded_notif);
assert_eq!(msg, decoded_msg);
}

#[test]
Expand All @@ -432,11 +441,12 @@ mod tests {
let mut nonce = [0u8; MESSAGE_NONCE_LENGTH];
nonce[MESSAGE_NONCE_LENGTH - nonce_bytes.len()..].copy_from_slice(&nonce_bytes);

let notif = Message::Notification(Notification::RelayMsg(inr_enr, nonce));
let notif = RelayMsgNotification::new(inr_enr, nonce);
let msg = Message::RelayMsgNotification(notif);

let encoded_notif = notif.clone().encode();
let decoded_notif = Message::decode(&encoded_notif).expect("Should decode");
let encoded_msg = msg.clone().encode();
let decoded_msg = Message::decode(&encoded_msg).expect("Should decode");

assert_eq!(notif, decoded_notif);
assert_eq!(msg, decoded_msg);
}
}
Loading

0 comments on commit 2ecc61f

Please sign in to comment.