diff --git a/.github/workflows/build.yml b/.github/workflows/build.yml index 44fd9dc7d..5fbe6ec54 100644 --- a/.github/workflows/build.yml +++ b/.github/workflows/build.yml @@ -19,7 +19,7 @@ jobs: - name: Get latest version of stable rust run: rustup update stable - name: Lint code for quality and style with Clippy - run: cargo clippy --workspace --tests --all-features -- -D warnings + run: cargo clippy --workspace --tests --all-features -- -D warnings -A clippy::assertions_on_constants release-tests-ubuntu: runs-on: ubuntu-latest needs: cargo-fmt @@ -50,7 +50,7 @@ jobs: - name: Get latest version of stable rust run: rustup update stable - name: Check rustdoc links - run: RUSTDOCFLAGS="--deny broken_intra_doc_links" cargo doc --verbose --workspace --no-deps --document-private-items + run: RUSTDOCFLAGS="--deny rustdoc::broken_intra_doc_links" cargo doc --verbose --workspace --no-deps --document-private-items cargo-udeps: name: cargo-udeps runs-on: ubuntu-latest diff --git a/Cargo.toml b/Cargo.toml index 23e29aa3e..351c7d62d 100644 --- a/Cargo.toml +++ b/Cargo.toml @@ -34,10 +34,11 @@ aes = { version = "0.7.5", features = ["ctr"] } aes-gcm = "0.9.4" tracing = { version = "0.1.29", features = ["log"] } tracing-subscriber = { version = "0.3.3", features = ["env-filter"] } -lru = "0.7.1" +lru = {version = "0.7.1", default-features = false } hashlink = "0.7.0" delay_map = "0.3.0" more-asserts = "0.2.2" +derive_more = { version = "0.99.17", default-features = false, features = ["from", "display", "deref", "deref_mut"] } [dev-dependencies] rand_07 = { package = "rand", version = "0.7" } diff --git a/src/config.rs b/src/config.rs index 3d4153229..0d3fafa3b 100644 --- a/src/config.rs +++ b/src/config.rs @@ -1,9 +1,14 @@ //! A set of configuration parameters to tune the discovery protocol. use crate::{ - kbucket::MAX_NODES_PER_BUCKET, socket::ListenConfig, Enr, Executor, PermitBanList, RateLimiter, + kbucket::MAX_NODES_PER_BUCKET, Enr, Executor, ListenConfig, PermitBanList, RateLimiter, RateLimiterBuilder, }; -use std::time::Duration; + +/// The minimum number of unreachable Sessions a node must allow. This enables the network to +/// boostrap. +const MIN_SESSIONS_UNREACHABLE_ENR: usize = 10; + +use std::{ops::RangeInclusive, time::Duration}; /// Configuration parameters that define the performance of the discovery network. #[derive(Clone)] @@ -96,6 +101,15 @@ pub struct Discv5Config { /// timing support. By default, the executor that created the discv5 struct will be used. pub executor: Option>, + /// The max limit for peers with unreachable ENRs. Benevolent examples of such peers are peers + /// that are discovering their externally reachable socket, nodes must assist at least one + /// such peer in discovering their reachable socket via ip voting, and peers behind symmetric + /// NAT. Default is no limit. Minimum is 10. + pub unreachable_enr_limit: Option, + + /// The unused port range to try and bind to when testing if this node is behind NAT based on + /// observed address reported at runtime by peers. + pub unused_port_range: Option>, /// Configuration for the sockets to listen on. pub listen_config: ListenConfig, } @@ -142,6 +156,8 @@ impl Discv5ConfigBuilder { permit_ban_list: PermitBanList::default(), ban_duration: Some(Duration::from_secs(3600)), // 1 hour executor: None, + unreachable_enr_limit: None, + unused_port_range: None, listen_config, }; @@ -302,6 +318,23 @@ impl Discv5ConfigBuilder { self } + /// Sets the maximum number of sessions with peers with unreachable ENRs to allow. Minimum is 1 + /// peer. Default is no limit. + pub fn unreachable_enr_limit(&mut self, peer_limit: Option) -> &mut Self { + self.config.unreachable_enr_limit = peer_limit; + self + } + + /// Sets the unused port range for testing if node is behind a NAT. Default is the range + /// covering user and dynamic ports. + pub fn unused_port_range( + &mut self, + unused_port_range: Option>, + ) -> &mut Self { + self.config.unused_port_range = unused_port_range; + self + } + pub fn build(&mut self) -> Discv5Config { // If an executor is not provided, assume a current tokio runtime is running. if self.config.executor.is_none() { @@ -309,6 +342,9 @@ impl Discv5ConfigBuilder { }; assert!(self.config.incoming_bucket_limit <= MAX_NODES_PER_BUCKET); + if let Some(limit) = self.config.unreachable_enr_limit { + assert!(limit >= MIN_SESSIONS_UNREACHABLE_ENR); + } self.config.clone() } diff --git a/src/error.rs b/src/error.rs index d35ba90d4..7b9198d8c 100644 --- a/src/error.rs +++ b/src/error.rs @@ -1,15 +1,24 @@ -use crate::{handler::Challenge, node_info::NonContactable}; +use crate::{ + handler::Challenge, + node_info::{NodeAddress, NonContactable}, +}; +use derive_more::From; use rlp::DecoderError; use std::fmt; -#[derive(Debug)] +#[derive(Debug, From)] /// A general error that is used throughout the Discv5 library. pub enum Discv5Error { + /// An invalid message type was received. + InvalidMessage, /// An invalid ENR was received. InvalidEnr, + /// The limit for sessions with peers that have an unreachable ENR is reached. + LimitSessionsUnreachableEnr, /// The public key type is known. UnknownPublicKey, /// The ENR key used is not supported. + #[from(ignore)] KeyTypeNotSupported(&'static str), /// Failed to derive an ephemeral public key. KeyDerivationFailed, @@ -27,25 +36,48 @@ pub enum Discv5Error { ServiceAlreadyStarted, /// A session could not be established with the remote. SessionNotEstablished, + /// A session to the given peer is already established. + SessionAlreadyEstablished(NodeAddress), /// An RLP decoding error occurred. RLPError(DecoderError), /// Failed to encrypt a message. + #[from(ignore)] EncryptionFail(String), /// Failed to decrypt a message. + #[from(ignore)] DecryptionFailed(String), /// The custom error has occurred. + #[from(ignore)] Custom(&'static str), /// A generic dynamic error occurred. + #[from(ignore)] Error(String), /// An IO error occurred. Io(std::io::Error), } -impl From for Discv5Error { - fn from(err: std::io::Error) -> Discv5Error { - Discv5Error::Io(err) - } +/// An error occurred whilst attempting to hole punch NAT. +#[derive(Debug)] +pub enum NatError { + /// Initiator error. + Initiator(Discv5Error), + /// Relayer error. + Relay(Discv5Error), + /// Target error. + Target(Discv5Error), +} + +macro_rules! impl_from_variant { + ($(<$($generic: ident,)+>)*, $from_type: ty, $to_type: ty, $variant: path) => { + impl$(<$($generic,)+>)* From<$from_type> for $to_type { + fn from(_e: $from_type) -> Self { + $variant + } + } + }; } +impl_from_variant!(, tokio::sync::mpsc::error::SendError, Discv5Error, Self::ServiceChannelClosed); +impl_from_variant!(, NonContactable, Discv5Error, Self::InvalidEnr); #[derive(Debug, Clone, PartialEq, Eq)] /// Types of packet errors. @@ -111,6 +143,9 @@ pub enum RequestError { InvalidMultiaddr(&'static str), /// Failure generating random numbers during request. EntropyFailure(&'static str), + /// Malicious peer tried to initiate nat hole punching for another peer. todo(emhane): this is + /// notification error. + MaliciousRelayInit, } #[derive(Debug, Clone, PartialEq, Eq)] diff --git a/src/handler/mod.rs b/src/handler/mod.rs index 625582402..74a54f348 100644 --- a/src/handler/mod.rs +++ b/src/handler/mod.rs @@ -29,11 +29,14 @@ use crate::{ config::Discv5Config, discv5::PERMIT_BAN_LIST, - error::{Discv5Error, RequestError}, + error::{Discv5Error, NatError, RequestError}, packet::{ChallengeData, IdNonce, MessageNonce, Packet, PacketKind, ProtocolIdentity}, - rpc::{Message, Request, RequestBody, RequestId, Response, ResponseBody}, + rpc::{ + Message, Payload, RelayInitNotification, RelayMsgNotification, Request, RequestBody, + RequestId, Response, ResponseBody, + }, socket, - socket::{FilterConfig, Socket}, + socket::{FilterConfig, Outbound, Socket}, Enr, }; use delay_map::HashMapDelay; @@ -56,16 +59,17 @@ use tracing::{debug, error, trace, warn}; mod active_requests; mod crypto; +mod nat; mod request_call; mod session; mod tests; -pub use crate::node_info::{NodeAddress, NodeContact}; - use crate::metrics::METRICS; +pub use crate::node_info::{NodeAddress, NodeContact}; use crate::{lru_time_cache::LruTimeCache, socket::ListenConfig}; use active_requests::ActiveRequests; +use nat::Nat; use request_call::RequestCall; use session::Session; @@ -80,8 +84,7 @@ const ONE_TIME_SESSION_TIMEOUT: u64 = 30; const ONE_TIME_SESSION_CACHE_CAPACITY: usize = 100; /// Messages sent from the application layer to `Handler`. -#[derive(Debug, Clone, PartialEq)] -#[allow(clippy::large_enum_variant)] +#[derive(Debug, Clone, PartialEq, Eq)] pub enum HandlerIn { /// A Request to send to a `NodeContact` has been received from the application layer. A /// `NodeContact` is an abstract type that allows for either an ENR to be sent or a `Raw` type @@ -104,11 +107,12 @@ pub enum HandlerIn { /// response back to the `NodeAddress` from which the request was received. Response(NodeAddress, Box), - /// A Random packet has been received and we have requested the application layer to inform - /// us what the highest known ENR is for this node. - /// The `WhoAreYouRef` is sent out in the `HandlerOut::WhoAreYou` event and should - /// be returned here to submit the application's response. - WhoAreYou(WhoAreYouRef, Option), + /// The application layer is responding with an ENR to a `RequestEnr` request. This function + /// returns the requested data and optionally and ENR if one is found. + EnrResponse(Option, EnrRequestData), + + /// Observed socket has been update. The old socket and the current socket. + SocketUpdate(Option, SocketAddr), } /// Messages sent between a node on the network and `Handler`. @@ -127,14 +131,23 @@ pub enum HandlerOut { /// A Response has been received from a node on the network. Response(NodeAddress, Box), - /// An unknown source has requested information from us. Return the reference with the known - /// ENR of this node (if known). See the `HandlerIn::WhoAreYou` variant. - WhoAreYou(WhoAreYouRef), + /// We need to request the ENR of a specific node. This could be due to an unknown ENR or a + /// hole punch request. + RequestEnr(EnrRequestData), /// An RPC request failed. /// /// This returns the request ID and an error indicating why the request failed. RequestFailed(RequestId, RequestError), + + /// Triggers a ping to all peers, outside of the regular ping interval. Needed to trigger + /// renewed session establishment after updating the local ENR from unreachable to reachable + /// and clearing all sessions. Only this way does the local node have a chance to make it into + /// its peers kbuckets before the session expires (defaults to 24 hours). This is the case + /// since its peers, running this implementation, will only respond to PINGs from nodes in its + /// kbucktes and unreachable ENRs don't make it into kbuckets upon [`HandlerOut::Established`] + /// event. + PingAllPeers, } /// How we connected to the node. @@ -146,6 +159,19 @@ pub enum ConnectionDirection { Outgoing, } +/// The kind of request data being sent to the service. +#[derive(Debug, Clone, PartialEq, Eq)] +pub enum EnrRequestData { + /// A Random packet has been received and request the application layer to inform + /// us what the highest known ENR is for this node. + /// The `WhoAreYouRef` is sent out in the `HandlerOut::WhoAreYou` event and should + /// be returned here to submit the application's response. + WhoAreYou(WhoAreYouRef), + /// Look-up an ENR in k-buckets. Passes the node id of the peer to look up and the + /// [`RelayMsgNotification`] we intend to send to it. + Nat(RelayInitNotification), +} + /// A reference for the application layer to send back when the handler requests any known /// ENR for the NodeContact. #[derive(Debug, Clone, PartialEq, Eq)] @@ -209,6 +235,8 @@ pub struct Handler { socket: Socket, /// Exit channel to shutdown the handler. exit: oneshot::Receiver<()>, + /// Struct to handle nat hole punching logic. + nat: Nat, } type HandlerReturn = ( @@ -237,16 +265,33 @@ impl Handler { // The local node id let node_id = enr.read().node_id(); + let Discv5Config { + enable_packet_filter, + filter_rate_limiter, + filter_max_nodes_per_ip, + filter_max_bans_per_ip, + listen_config, + executor, + ban_duration, + session_cache_capacity, + session_timeout, + unreachable_enr_limit, + unused_port_range, + request_retries, + request_timeout, + .. + } = config; + // enable the packet filter if required let filter_config = FilterConfig { - enabled: config.enable_packet_filter, - rate_limiter: config.filter_rate_limiter.clone(), - max_nodes_per_ip: config.filter_max_nodes_per_ip, - max_bans_per_ip: config.filter_max_bans_per_ip, + enabled: enable_packet_filter, + rate_limiter: filter_rate_limiter, + max_nodes_per_ip: filter_max_nodes_per_ip, + max_bans_per_ip: filter_max_bans_per_ip, }; let mut listen_sockets = SmallVec::default(); - match config.listen_config { + match listen_config { ListenConfig::Ipv4 { ip, port } => listen_sockets.push((ip, port).into()), ListenConfig::Ipv6 { ip, port } => listen_sockets.push((ip, port).into()), ListenConfig::DualStack { @@ -260,44 +305,54 @@ impl Handler { } }; + let ip_mode = listen_config.ip_mode(); + let socket_config = socket::SocketConfig { - executor: config.executor.clone().expect("Executor must exist"), + executor: executor.clone().expect("Executor must exist"), filter_config, - listen_config: config.listen_config.clone(), + listen_config, local_node_id: node_id, expected_responses: filter_expected_responses.clone(), - ban_duration: config.ban_duration, + ban_duration, }; // Attempt to bind to the socket before spinning up the send/recv tasks. let socket = Socket::new::

(socket_config).await?; - config - .executor - .clone() + let sessions = LruTimeCache::new(session_timeout, Some(session_cache_capacity)); + + let nat = Nat::new( + &listen_sockets, + &enr.read(), + ip_mode, + unused_port_range, + ban_duration, + session_cache_capacity, + unreachable_enr_limit, + ); + + executor .expect("Executor must be present") .spawn(Box::pin(async move { let mut handler = Handler { - request_retries: config.request_retries, + request_retries, node_id, enr, key, - active_requests: ActiveRequests::new(config.request_timeout), + active_requests: ActiveRequests::new(request_timeout), pending_requests: HashMap::new(), filter_expected_responses, - sessions: LruTimeCache::new( - config.session_timeout, - Some(config.session_cache_capacity), - ), + sessions, one_time_sessions: LruTimeCache::new( Duration::from_secs(ONE_TIME_SESSION_TIMEOUT), Some(ONE_TIME_SESSION_CACHE_CAPACITY), ), - active_challenges: HashMapDelay::new(config.request_timeout), + active_challenges: HashMapDelay::new(request_timeout), service_recv, service_send, listen_sockets, socket, + nat, exit, }; debug!("Handler Starting"); @@ -325,20 +380,65 @@ impl Handler { } } HandlerIn::Response(dst, response) => self.send_response::

(dst, *response).await, - HandlerIn::WhoAreYou(wru_ref, enr) => self.send_challenge::

(wru_ref, enr).await, + HandlerIn::EnrResponse(enr, EnrRequestData::WhoAreYou(wru_ref)) => self.send_challenge::

(wru_ref, enr).await, + HandlerIn::EnrResponse(Some(target_enr), EnrRequestData::Nat(relay_initiation)) => { + // Assemble the notification for the target + let (initiator_enr, _target, timed_out_nonce) = relay_initiation.into(); + let relay_msg_notification = RelayMsgNotification::new(initiator_enr, timed_out_nonce); + if let Err(e) = self.send_relay_msg_notification::

(target_enr, relay_msg_notification).await { + warn!("Failed to relay. Error: {:?}", e); + } + } + HandlerIn::EnrResponse(_,_) => {} // This handles the case that No ENR was + // found for a target relayer. This + // message never gets sent, so it is + // ignored. + HandlerIn::SocketUpdate(old_socket, socket) => { + let ip = socket.ip(); + let port = socket.port(); + if old_socket.is_none() { + // This node goes from being unreachable to being reachable, but + // keeps the same enr key (hence same node id). Remove its + // sessions to trigger a WHOAREYOU from peers on next sent + // message. If the peer is running this implementation of + // discovery, this makes it possible for the local node to be + // inserted into its peers' kbuckets before the session they + // already had expires. Session duration, in this impl defaults to + // 24 hours. + self.sessions.clear(); + if let Err(e) = self + .service_send + .send(HandlerOut::PingAllPeers) + .await + { + warn!("Failed to inform that request failed {}", e); + } + } + self.nat.set_is_behind_nat(&self.listen_sockets, Some(ip), Some(port)); + } } } Some(inbound_packet) = self.socket.recv.recv() => { self.process_inbound_packet::

(inbound_packet).await; } Some(Ok((node_address, pending_request))) = self.active_requests.next() => { - self.handle_request_timeout(node_address, pending_request).await; + self.handle_request_timeout::

(node_address, pending_request).await; } Some(Ok((node_address, _challenge))) = self.active_challenges.next() => { // A challenge has expired. There could be pending requests awaiting this // challenge. We process them here self.send_next_request::

(node_address).await; } + Some(Ok(peer_socket)) = self.nat.hole_punch_tracker.next() => { + if self.nat.is_behind_nat == Some(false) { + // Until ip voting is done and an observed public address is finalised, all nodes act as + // if they are behind a NAT. + return; + } + if let Err(e) = self.on_hole_punch_expired(peer_socket).await { + warn!("Failed to keep hole punched for peer, error: {:?}", e); + } + } _ = banned_nodes_check.tick() => self.unban_nodes_check(), // Unban nodes that are past the timeout _ = &mut self.exit => { return; @@ -400,6 +500,19 @@ impl Handler { ) .await } + PacketKind::SessionMessage { src_id } => { + let node_address = NodeAddress { + socket_addr: inbound_packet.src_address, + node_id: src_id, + }; + self.handle_session_message::

( + node_address, + message_nonce, + &inbound_packet.message, + &inbound_packet.authenticated_data, + ) + .await + } } } @@ -424,13 +537,42 @@ impl Handler { } /// A request has timed out. - async fn handle_request_timeout( + async fn handle_request_timeout( &mut self, node_address: NodeAddress, mut request_call: RequestCall, ) { if request_call.retries() >= self.request_retries { trace!("Request timed out with {}", node_address); + if let Some(relay) = self + .nat + .new_peer_latest_relay_cache + .pop(&node_address.node_id) + { + // The request might be timing out because the peer is behind a NAT. If we + // have a relay to the peer, attempt NAT hole punching. + let target = request_call.contact().node_address(); + trace!("Trying to hole punch target {target} with relay {relay}"); + let local_enr = self.enr.read().clone(); + let nonce = request_call.packet().header.message_nonce; + match self + .on_request_time_out::

(relay, local_enr, nonce, target) + .await + { + Err(NatError::Initiator(Discv5Error::SessionAlreadyEstablished( + node_address, + ))) => { + debug!("Session to peer already established, aborting hole punch attempt. Peer: {node_address}"); + } + Err(e) => { + warn!("Failed to start hole punching. Error: {:?}", e); + } + Ok(()) => { + self.active_requests.insert(node_address, request_call); + return; + } + } + } // Remove the request from the awaiting packet_filter self.remove_expected_response(node_address.socket_addr); // The request has timed out. We keep any established session for future use. @@ -464,14 +606,15 @@ impl Handler { return Err(RequestError::SelfRequest); } - // If there is already an active request or an active challenge (WHOAREYOU sent) for this node, add to pending requests + // If there is already an active request or an active challenge (WHOAREYOU sent) for this + // node, add to pending requests if self.active_requests.get(&node_address).is_some() || self.active_challenges.get(&node_address).is_some() { trace!("Request queued for node: {}", node_address); self.pending_requests .entry(node_address) - .or_insert_with(Vec::new) + .or_default() .push(PendingRequest { contact, request_id, @@ -529,10 +672,10 @@ impl Handler { ) { // Check for an established session let packet = if let Some(session) = self.sessions.get_mut(&node_address) { - session.encrypt_message::

(self.node_id, &response.encode()) + session.encrypt_session_message::

(self.node_id, &response.encode()) } else if let Some(mut session) = self.remove_one_time_session(&node_address, &response.id) { - session.encrypt_message::

(self.node_id, &response.encode()) + session.encrypt_session_message::

(self.node_id, &response.encode()) } else { // Either the session is being established or has expired. We simply drop the // response in this case. @@ -688,6 +831,10 @@ impl Handler { // All sent requests must have an associated node_id. Therefore the following // must not panic. let node_address = request_call.contact().node_address(); + + // Keep track if the ENR is reachable. In the case we don't know the ENR, we assume its + // fine. + let mut enr_not_reachable = false; match request_call.contact().enr() { Some(enr) => { // NOTE: Here we decide if the session is outgoing or ingoing. The condition for an @@ -700,6 +847,8 @@ impl Handler { ConnectionDirection::Incoming }; + enr_not_reachable = Nat::is_enr_reachable(&enr); + // We already know the ENR. Send the handshake response packet trace!("Sending Authentication response to node: {}", node_address); request_call.update_packet(auth_packet.clone()); @@ -711,14 +860,8 @@ impl Handler { self.send(node_address.clone(), auth_packet).await; // Notify the application that the session has been established - self.service_send - .send(HandlerOut::Established( - enr, - node_address.socket_addr, - connection_direction, - )) - .await - .unwrap_or_else(|e| warn!("Error with sending channel: {}", e)); + self.new_connection(enr, node_address.socket_addr, connection_direction) + .await; } None => { // Don't know the ENR. Establish the session, but request an ENR also @@ -743,7 +886,7 @@ impl Handler { } } } - self.new_session(node_address, session); + self.new_session(node_address, session, enr_not_reachable); } /// Verifies a Node ENR to it's observed address. If it fails, any associated session is also @@ -783,14 +926,35 @@ impl Handler { ); if let Some(challenge) = self.active_challenges.remove(&node_address) { + // Find the most recent ENR, a known ENR or one they sent in their challenge. + let Challenge { data, remote_enr } = challenge; + let Ok(most_recent_enr) = most_recent_enr(enr_record, remote_enr) else { + warn!( + "Peer did not respond with their ENR. Session could not be established. Node: {}",node_address + ); + self.fail_session(&node_address, RequestError::InvalidRemotePacket, true) + .await; + return; + }; + + // Keep count of the unreachable Sessions we are tracking + // Peer is reachable + let enr_not_reachable = !Nat::is_enr_reachable(&most_recent_enr); + + // Decide whether to establish this connection based on our appetite for unreachable + if enr_not_reachable && Some(self.sessions.tagged()) >= self.nat.unreachable_enr_limit { + debug!("Reached limit of unreachable ENR sessions. Avoiding a new connection. Limit: {}", self.sessions.tagged()); + return; + } + match Session::establish_from_challenge( self.key.clone(), &self.node_id, &node_address.node_id, - challenge, + data, id_nonce_sig, ephem_pubkey, - enr_record, + most_recent_enr, ) { Ok((mut session, enr)) => { // Receiving an AuthResponse must give us an up-to-date view of the node ENR. @@ -800,18 +964,16 @@ impl Handler { // Notify the application // The session established here are from WHOAREYOU packets that we sent. // This occurs when a node established a connection with us. - if let Err(e) = self - .service_send - .send(HandlerOut::Established( - enr, - node_address.socket_addr, - ConnectionDirection::Incoming, - )) - .await - { - warn!("Failed to inform of established session {}", e) - } - self.new_session(node_address.clone(), session); + self.new_connection( + enr, + node_address.socket_addr, + ConnectionDirection::Incoming, + ) + .await; + self.new_session(node_address.clone(), session, enr_not_reachable); + self.nat + .new_peer_latest_relay_cache + .pop(&node_address.node_id); self.handle_message::

( node_address.clone(), message_nonce, @@ -933,8 +1095,94 @@ impl Handler { } } + /// Handle a session message packet, that is dropped if it can't be decrypted. + async fn handle_session_message( + &mut self, + node_address: NodeAddress, // session message sender + message_nonce: MessageNonce, + message: &[u8], + authenticated_data: &[u8], + ) { + // check if we have an available session + let Some(session) = self.sessions.get_mut(&node_address) else { + warn!( + "Dropping message. Error: {}, {}", + Discv5Error::SessionNotEstablished, + node_address + ); + return; + }; + // attempt to decrypt notification (same decryption as for a message) + let message = match session.decrypt_message(message_nonce, message, authenticated_data) { + Err(e) => { + // We have a session, but the session message could not be decrypted. It is + // likely the node sending this message has dropped their session. Since + // this is a session message that assumes an established session, we do + // not reply with a WHOAREYOU to this random packet. This means we drop + // the packet. + warn!( + "Dropping message that should have been part of a session. Error: {}", + e + ); + return; + } + Ok(ref bytes) => match Message::decode(bytes) { + Ok(message) => message, + Err(err) => { + warn!( + "Failed to decode message. Error: {:?}, {}", + err, node_address + ); + return; + } + }, + }; + + match message { + Message::Response(response) => self.handle_response::

(node_address, response).await, + Message::RelayInitNotification(notification) => { + let initiator_node_id = notification.initiator_enr().node_id(); + if initiator_node_id != node_address.node_id { + warn!("peer {node_address} tried to initiate hole punch attempt for another node {initiator_node_id}, banning peer {node_address}"); + self.fail_session(&node_address, RequestError::MaliciousRelayInit, true) + .await; + let ban_timeout = self.nat.ban_duration.map(|v| Instant::now() + v); + PERMIT_BAN_LIST.write().ban(node_address, ban_timeout); + } else if let Err(e) = self.on_relay_initiation(notification).await { + warn!( + "failed handling notification to relay for {node_address}, {:?}", + e + ); + } + } + Message::RelayMsgNotification(notification) => { + match self.nat.is_behind_nat { + Some(false) => { + // inr may not be malicious and initiated a hole punch attempt when + // a request to this node timed out for another reason + debug!("peer {node_address} relayed a hole punch notification but we are not behind nat"); + } + _ => { + if let Err(e) = self.on_relay_msg::

(notification).await { + warn!( + "failed handling notification relayed from {node_address}, {:?}", + e + ); + } + } + } + } + Message::Request(_) => { + warn!( + "Peer sent message type {} that shouldn't be sent in packet type `Session Message`, {}", + message.msg_type(), + node_address, + ); + } + } + } + /// Handle a standard message that does not contain an authentication header. - #[allow(clippy::single_match)] async fn handle_message( &mut self, node_address: NodeAddress, @@ -972,7 +1220,9 @@ impl Handler { let whoareyou_ref = WhoAreYouRef(node_address, message_nonce); if let Err(e) = self .service_send - .send(HandlerOut::WhoAreYou(whoareyou_ref)) + .send(HandlerOut::RequestEnr(EnrRequestData::WhoAreYou( + whoareyou_ref, + ))) .await { warn!("Failed to send WhoAreYou to the service {}", e) @@ -986,7 +1236,6 @@ impl Handler { trace!("Received message from: {}", node_address); - // Remove any associated request from pending_request match message { Message::Request(request) => { // report the request to the application @@ -999,45 +1248,16 @@ impl Handler { } } Message::Response(response) => { - // Sessions could be awaiting an ENR response. Check if this response matches - // these - if let Some(request_id) = session.awaiting_enr.as_ref() { - if &response.id == request_id { - session.awaiting_enr = None; - match response.body { - ResponseBody::Nodes { mut nodes, .. } => { - // Received the requested ENR - if let Some(enr) = nodes.pop() { - if self.verify_enr(&enr, &node_address) { - // Notify the application - // This can occur when we try to dial a node without an - // ENR. In this case we have attempted to establish the - // connection, so this is an outgoing connection. - if let Err(e) = self - .service_send - .send(HandlerOut::Established( - enr, - node_address.socket_addr, - ConnectionDirection::Outgoing, - )) - .await - { - warn!("Failed to inform established outgoing connection {}", e) - } - return; - } - } - } - _ => {} - } - debug!("Session failed invalid ENR response"); - self.fail_session(&node_address, RequestError::InvalidRemoteEnr, true) - .await; - return; - } - } - // Handle standard responses - self.handle_response::

(node_address, response).await; + // Accept response in Message packet for backwards compatibility + warn!("Received a response in a `Message` packet, should be sent in a `SessionMessage`"); + self.handle_response::

(node_address, response).await + } + Message::RelayInitNotification(_) | Message::RelayMsgNotification(_) => { + warn!( + "Peer sent message type {} that shouldn't be sent in packet type `Message`, {}", + message.msg_type(), + node_address + ); } } } else { @@ -1048,7 +1268,9 @@ impl Handler { let whoareyou_ref = WhoAreYouRef(node_address, message_nonce); if let Err(e) = self .service_send - .send(HandlerOut::WhoAreYou(whoareyou_ref)) + .send(HandlerOut::RequestEnr(EnrRequestData::WhoAreYou( + whoareyou_ref, + ))) .await { warn!( @@ -1066,6 +1288,49 @@ impl Handler { node_address: NodeAddress, response: Response, ) { + // Sessions could be awaiting an ENR response. Check if this response matches + // this + // check if we have an available session + let Some(session) = self.sessions.get_mut(&node_address) else { + warn!( + "Dropping response. Error: {}, {}", + Discv5Error::SessionNotEstablished, + node_address + ); + return; + }; + + if let Some(request_id) = session.awaiting_enr.as_ref() { + if &response.id == request_id { + session.awaiting_enr = None; + if let ResponseBody::Nodes { mut nodes, .. } = response.body { + // Received the requested ENR + let Some(enr) = nodes.pop() else { + return; + }; + if self.verify_enr(&enr, &node_address) { + // Notify the application + // This can occur when we try to dial a node without an + // ENR. In this case we have attempted to establish the + // connection, so this is an outgoing connection. + self.new_connection( + enr, + node_address.socket_addr, + ConnectionDirection::Outgoing, + ) + .await; + return; + } + } + debug!("Session failed invalid ENR response"); + self.fail_session(&node_address, RequestError::InvalidRemoteEnr, true) + .await; + return; + } + } + + // Handle standard responses + // Find a matching request, if any if let Some(mut request_call) = self.active_requests.remove(&node_address) { let id = match request_call.id() { @@ -1085,7 +1350,21 @@ impl Handler { // Check to see if this is a Nodes response, in which case we may require to wait for // extra responses - if let ResponseBody::Nodes { total, .. } = response.body { + if let ResponseBody::Nodes { total, ref nodes } = response.body { + for node in nodes { + if let Some(socket_addr) = self.nat.ip_mode.get_contactable_addr(node) { + let node_id = node.node_id(); + let new_peer_node_address = NodeAddress { + socket_addr, + node_id, + }; + if self.sessions.peek(&new_peer_node_address).is_none() { + self.nat + .new_peer_latest_relay_cache + .put(node_id, node_address.clone()); + } + } + } if total > 1 { // This is a multi-response Nodes response if let Some(remaining_responses) = request_call.remaining_responses_mut() { @@ -1152,11 +1431,18 @@ impl Handler { self.active_requests.insert(node_address, request_call); } - fn new_session(&mut self, node_address: NodeAddress, session: Session) { + /// Updates the session cache for a new session. + fn new_session( + &mut self, + node_address: NodeAddress, + session: Session, + enr_not_reachable: bool, + ) { if let Some(current_session) = self.sessions.get_mut(&node_address) { current_session.update(session); } else { - self.sessions.insert(node_address, session); + self.sessions + .insert_raw(node_address, session, enr_not_reachable); METRICS .active_sessions .store(self.sessions.len(), Ordering::Relaxed); @@ -1204,8 +1490,10 @@ impl Handler { } } } - let node_address = request_call.contact().node_address(); + self.nat + .new_peer_latest_relay_cache + .pop(&node_address.node_id); self.fail_session(&node_address, error, remove_session) .await; } @@ -1222,6 +1510,8 @@ impl Handler { METRICS .active_sessions .store(self.sessions.len(), Ordering::Relaxed); + // stop keeping hole punched for peer + self.nat.untrack(&node_address.socket_addr); } if let Some(to_remove) = self.pending_requests.remove(node_address) { for PendingRequest { request_id, .. } in to_remove { @@ -1243,15 +1533,22 @@ impl Handler { } } - /// Sends a packet to the send handler to be encoded and sent. + /// Assembles and sends a [`Packet`]. async fn send(&mut self, node_address: NodeAddress, packet: Packet) { let outbound_packet = socket::OutboundPacket { node_address, packet, }; - if let Err(e) = self.socket.send.send(outbound_packet).await { + self.send_outbound(outbound_packet.into()).await; + } + + /// Sends a packet to the send handler to be encoded and sent. + async fn send_outbound(&mut self, packet: Outbound) { + let dst = *packet.dst(); + if let Err(e) = self.socket.send.send(packet).await { warn!("Failed to send outbound packet {}", e) } + self.nat.track(dst); } /// Check if any banned nodes have served their time and unban them. @@ -1265,4 +1562,186 @@ impl Handler { .ban_nodes .retain(|_, time| time.is_none() || Some(Instant::now()) < *time); } + + async fn new_connection( + &mut self, + enr: Enr, + socket_addr: SocketAddr, + conn_dir: ConnectionDirection, + ) { + if let Err(e) = self + .service_send + .send(HandlerOut::Established(enr, socket_addr, conn_dir)) + .await + { + warn!( + "Failed to inform of established connection {}, {}", + conn_dir, e + ) + } + } +} + +/// Given two optional ENRs, find the most recent one based on the sequence number. +/// This function will error if both inputs are None. +fn most_recent_enr(first: Option, second: Option) -> Result { + match (first, second) { + (Some(first_enr), Some(second_enr)) => { + if first_enr.seq() > second_enr.seq() { + Ok(first_enr) + } else { + Ok(second_enr) + } + } + (Some(first), None) => Ok(first), + (None, Some(second)) => Ok(second), + (None, None) => Err(()), // No ENR provided + } +} + +// NAT-related functions +impl Handler { + /// A request times out. Should trigger the initiation of a hole punch attempt, given a + /// transitive route to the target exists. Sends a RELAYINIT notification to the given + /// relay. + async fn on_request_time_out( + &mut self, + relay: NodeAddress, + local_enr: Enr, // initiator-enr + timed_out_nonce: MessageNonce, + target_node_address: NodeAddress, + ) -> Result<(), NatError> { + // Another hole punch process with this target may have just completed. + if self.sessions.get(&target_node_address).is_some() { + return Err(NatError::Initiator(Discv5Error::SessionAlreadyEstablished( + target_node_address, + ))); + } + if let Some(session) = self.sessions.get_mut(&relay) { + let relay_init_notif = + RelayInitNotification::new(local_enr, target_node_address.node_id, timed_out_nonce); + trace!( + "Sending notif to relay {}. relay init: {}", + relay.node_id, + relay_init_notif, + ); + // Encrypt the message and send + let packet = match session + .encrypt_session_message::

(self.node_id, &relay_init_notif.encode()) + { + Ok(packet) => packet, + Err(e) => { + return Err(NatError::Initiator(e)); + } + }; + self.send(relay, packet).await; + } else { + // Drop hole punch attempt with this relay, to ensure hole punch round-trip time stays + // within the time out of the udp entrypoint for the target peer in the initiator's + // router, set by the original timed out FINDNODE request from the initiator, as the + // initiator may also be behind a NAT. + warn!( + "Session is not established. Dropping relay notification for relay: {}", + relay.node_id + ); + } + Ok(()) + } + + /// A RelayInit notification is received over discv5 indicating this node is the relay. Should + /// trigger sending a RelayMsg to the target. + async fn on_relay_initiation( + &mut self, + relay_initiation: RelayInitNotification, + ) -> Result<(), NatError> { + // Check for target peer in our kbuckets otherwise drop notification. + if let Err(e) = self + .service_send + .send(HandlerOut::RequestEnr(EnrRequestData::Nat( + relay_initiation, + ))) + .await + { + return Err(NatError::Relay(e.into())); + } + Ok(()) + } + + /// A RelayMsg notification is received over discv5 indicating this node is the target. Should + /// trigger a WHOAREYOU to be sent to the initiator using the `nonce` in the RelayMsg. + async fn on_relay_msg( + &mut self, + relay_msg: RelayMsgNotification, + ) -> Result<(), NatError> { + let (inr_enr, timed_out_msg_nonce) = relay_msg.into(); + let initiator_node_address = match NodeContact::try_from_enr(inr_enr, self.nat.ip_mode) { + Ok(contact) => contact.node_address(), + Err(e) => return Err(NatError::Target(e.into())), + }; + + // A session may already have been established. + if self.sessions.get(&initiator_node_address).is_some() { + trace!("Session already established with initiator: {initiator_node_address}"); + return Ok(()); + } + // Possibly, an attempt to punch this hole, using another relay, is in progress. + if self + .active_challenges + .get(&initiator_node_address) + .is_some() + { + trace!("WHOAREYOU packet already sent to initiator: {initiator_node_address}"); + return Ok(()); + } + + // If not hole punch attempts are in progress, spawn a WHOAREYOU event to punch a hole in + // our NAT for initiator. + let whoareyou_ref = WhoAreYouRef(initiator_node_address, timed_out_msg_nonce); + self.send_challenge::

(whoareyou_ref, None).await; + + Ok(()) + } + + /// Send a RELAYMSG notification. + async fn send_relay_msg_notification( + &mut self, + target_enr: Enr, + relay_msg_notification: RelayMsgNotification, + ) -> Result<(), NatError> { + let target_node_address = match NodeContact::try_from_enr(target_enr, self.nat.ip_mode) { + Ok(contact) => contact.node_address(), + Err(e) => return Err(NatError::Relay(e.into())), + }; + if let Some(session) = self.sessions.get_mut(&target_node_address) { + trace!( + "Sending notification to target {}. relay msg: {}", + target_node_address.node_id, + relay_msg_notification, + ); + // Encrypt the notification and send + let packet = match session + .encrypt_session_message::

(self.node_id, &relay_msg_notification.encode()) + { + Ok(packet) => packet, + Err(e) => { + return Err(NatError::Relay(e)); + } + }; + self.send(target_node_address, packet).await; + Ok(()) + } else { + // Either the session is being established or has expired. We simply drop the + // notification in this case to ensure hole punch round-trip time stays within the + // time out of the udp entrypoint for the target peer in the initiator's NAT, set by + // the original timed out FINDNODE request from the initiator, as the initiator may + // also be behind a NAT. + Err(NatError::Relay(Discv5Error::SessionNotEstablished)) + } + } + + #[inline] + async fn on_hole_punch_expired(&mut self, peer: SocketAddr) -> Result<(), NatError> { + self.send_outbound(peer.into()).await; + Ok(()) + } } diff --git a/src/handler/nat.rs b/src/handler/nat.rs new file mode 100644 index 000000000..b9e859f65 --- /dev/null +++ b/src/handler/nat.rs @@ -0,0 +1,190 @@ +use std::{ + net::{IpAddr, SocketAddr, UdpSocket}, + ops::RangeInclusive, + time::Duration, +}; + +use delay_map::HashSetDelay; +use enr::NodeId; +use lru::LruCache; +use rand::Rng; + +use crate::{node_info::NodeAddress, Enr, IpMode}; + +/// The expected shortest lifetime in most NAT configurations of a punched hole in seconds. +pub const DEFAULT_HOLE_PUNCH_LIFETIME: u64 = 20; +/// The default number of ports to try before concluding that the local node is behind NAT. +pub const PORT_BIND_TRIES: usize = 4; +/// Port range that is not impossible to bind to. +pub const USER_AND_DYNAMIC_PORTS: RangeInclusive = 1025..=u16::MAX; + +/// Aggregates types necessary to implement nat hole punching for [`crate::handler::Handler`]. +pub struct Nat { + /// Ip mode as set in config. + pub ip_mode: IpMode, + /// This node has been observed to be behind a NAT. + pub is_behind_nat: Option, + /// The last peer to send us a new peer in a NODES response is stored as the new peer's + /// potential relay until the first request to the new peer after its discovery is either + /// responded or failed. The cache will usually be emptied by successful or failed session + /// establishment, but for the edge case that a NODES response is returned for an ended query + /// and hence an attempt to establish a session with those nodes isn't initiated, a bound on + /// the relay cache is set equivalent to the Handler's `session_cache_capacity`. + pub new_peer_latest_relay_cache: LruCache, + /// Keeps track if this node needs to send a packet to a peer in order to keep a hole punched + /// for it in its NAT. + pub hole_punch_tracker: HashSetDelay, + /// Ports to trie to bind to check if this node is behind NAT. + pub unused_port_range: Option>, + /// If the filter is enabled this sets the default timeout for bans enacted by the filter. + pub ban_duration: Option, + /// The number of unreachable ENRs we store at most in our session cache. + pub unreachable_enr_limit: Option, +} + +impl Nat { + pub fn new( + listen_sockets: &[SocketAddr], + local_enr: &Enr, + ip_mode: IpMode, + unused_port_range: Option>, + ban_duration: Option, + session_cache_capacity: usize, + unreachable_enr_limit: Option, + ) -> Self { + let mut nat = Nat { + ip_mode, + is_behind_nat: None, + new_peer_latest_relay_cache: LruCache::new(session_cache_capacity), + hole_punch_tracker: HashSetDelay::new(Duration::from_secs(DEFAULT_HOLE_PUNCH_LIFETIME)), + unused_port_range, + ban_duration, + unreachable_enr_limit, + }; + // Optimistically only test one advertised socket, ipv4 has precedence. If it is + // reachable, assumption is made that also the other ip version socket is reachable. + match ( + local_enr.ip4(), + local_enr.udp4(), + local_enr.ip6(), + local_enr.udp6(), + ) { + (Some(ip), port, _, _) => { + nat.set_is_behind_nat(listen_sockets, Some(ip.into()), port); + } + (_, _, Some(ip6), port) => { + nat.set_is_behind_nat(listen_sockets, Some(ip6.into()), port); + } + (None, Some(port), _, _) | (_, _, None, Some(port)) => { + nat.set_is_behind_nat(listen_sockets, None, Some(port)); + } + (None, None, None, None) => {} + } + nat + } + + pub fn track(&mut self, peer_socket: SocketAddr) { + if self.is_behind_nat == Some(false) { + return; + } + self.hole_punch_tracker.insert(peer_socket); + } + + pub fn untrack(&mut self, peer_socket: &SocketAddr) { + _ = self.hole_punch_tracker.remove(peer_socket) + } + + /// Called when a new observed address is reported at start up or after a + /// [`crate::Discv5Event::SocketUpdated`]. + pub fn set_is_behind_nat( + &mut self, + listen_sockets: &[SocketAddr], + observed_ip: Option, + observed_port: Option, + ) { + if !listen_sockets + .iter() + .any(|listen_socket| Some(listen_socket.port()) == observed_port) + { + self.is_behind_nat = Some(true); + return; + } + + // Without and observed IP it is too early to conclude if the local node is behind a NAT, + // return. + let Some(ip) = observed_ip else { + return; + }; + + self.is_behind_nat = Some(match is_behind_nat(ip, &self.unused_port_range) { + true => true, + false => { + // node assume it is behind NAT until now + self.hole_punch_tracker.clear(); + false + } + }); + } + + /// Determines if an ENR is reachable or not based on its assigned keys. + pub fn is_enr_reachable(enr: &Enr) -> bool { + enr.udp4_socket().is_some() || enr.udp6_socket().is_some() + } +} + +/// Helper function to test if the local node is behind NAT based on the node's observed reachable +/// socket. +fn is_behind_nat(observed_ip: IpAddr, unused_port_range: &Option>) -> bool { + // If the node cannot bind to the observed address at any of some random ports, we + // conclude it is behind NAT. + let mut rng = rand::thread_rng(); + let unused_port_range = match unused_port_range { + Some(range) => range, + None => &USER_AND_DYNAMIC_PORTS, + }; + for _ in 0..PORT_BIND_TRIES { + let rnd_port: u16 = rng.gen_range(unused_port_range.clone()); + if UdpSocket::bind((observed_ip, rnd_port)).is_ok() { + return false; + } + } + true +} + +#[cfg(test)] +mod test { + use crate::return_if_ipv6_is_not_supported; + + use super::*; + + #[test] + fn test_is_not_behind_nat() { + assert!(!is_behind_nat(IpAddr::from([127, 0, 0, 1]), &None)); + } + + #[test] + fn test_is_behind_nat() { + assert!(is_behind_nat(IpAddr::from([8, 8, 8, 8]), &None)); + } + + // ipv6 tests don't run in github ci https://github.com/actions/runner-images/issues/668 + #[test] + fn test_is_not_behind_nat_ipv6() { + return_if_ipv6_is_not_supported!(); + + assert!(!is_behind_nat( + IpAddr::from([0u16, 0u16, 0u16, 0u16, 0u16, 0u16, 0u16, 1u16]), + &None, + )); + } + + // ipv6 tests don't run in github ci https://github.com/actions/runner-images/issues/668 + #[test] + fn test_is_behind_nat_ipv6() { + // google's ipv6 + assert!(is_behind_nat( + IpAddr::from([2001, 4860, 4860, 0u16, 0u16, 0u16, 0u16, 0u16]), + &None, + )); + } +} diff --git a/src/handler/request_call.rs b/src/handler/request_call.rs index b91c7643e..a776dfdb1 100644 --- a/src/handler/request_call.rs +++ b/src/handler/request_call.rs @@ -1,7 +1,7 @@ -pub use crate::node_info::{NodeAddress, NodeContact}; +pub use crate::node_info::NodeContact; use crate::{ packet::Packet, - rpc::{Request, RequestBody}, + rpc::{Payload, Request, RequestBody}, }; use super::HandlerReqId; diff --git a/src/handler/session.rs b/src/handler/session.rs index 9f79f9017..d0f53f305 100644 --- a/src/handler/session.rs +++ b/src/handler/session.rs @@ -1,11 +1,18 @@ use super::*; use crate::{ + handler::Challenge, node_info::NodeContact, packet::{ - ChallengeData, Packet, PacketHeader, PacketKind, ProtocolIdentity, MESSAGE_NONCE_LENGTH, + ChallengeData, MessageNonce, Packet, PacketHeader, PacketKind, ProtocolIdentity, + MESSAGE_NONCE_LENGTH, }, + rpc::RequestId, + Discv5Error, Enr, }; + use enr::{CombinedKey, NodeId}; +use parking_lot::RwLock; +use std::sync::Arc; use zeroize::Zeroize; #[derive(Zeroize, PartialEq)] @@ -16,6 +23,15 @@ pub(crate) struct Keys { decryption_key: [u8; 16], } +impl From<([u8; 16], [u8; 16])> for Keys { + fn from((encryption_key, decryption_key): ([u8; 16], [u8; 16])) -> Self { + Keys { + encryption_key, + decryption_key, + } + } +} + /// A Session containing the encryption/decryption keys. These are kept individually for a given /// node. pub(crate) struct Session { @@ -55,17 +71,33 @@ impl Session { self.awaiting_enr = new_session.awaiting_enr; } - /// Uses the current `Session` to encrypt a message. Encrypt packets with the current session - /// key if we are awaiting a response from AuthMessage. + /// Uses the current `Session` to encrypt a `SessionMessage`. + pub(crate) fn encrypt_session_message( + &mut self, + src_id: NodeId, + message: &[u8], + ) -> Result { + self.encrypt::

(message, PacketKind::SessionMessage { src_id }) + } + + /// Uses the current `Session` to encrypt a `Message`. pub(crate) fn encrypt_message( &mut self, src_id: NodeId, message: &[u8], + ) -> Result { + self.encrypt::

(message, PacketKind::Message { src_id }) + } + + /// Encrypts packets with the current session key if we are awaiting a response from + /// AuthMessage. + fn encrypt( + &mut self, + message: &[u8], + packet_kind: PacketKind, ) -> Result { self.counter += 1; - // If the message nonce length is ever set below 4 bytes this will explode. The packet - // size constants shouldn't be modified. let random_nonce: [u8; MESSAGE_NONCE_LENGTH - 4] = rand::random(); let mut message_nonce: MessageNonce = [0u8; MESSAGE_NONCE_LENGTH]; message_nonce[..4].copy_from_slice(&self.counter.to_be_bytes()); @@ -75,7 +107,7 @@ impl Session { let iv: u128 = rand::random(); let header = PacketHeader { message_nonce, - kind: PacketKind::Message { src_id }, + kind: packet_kind, }; let mut authenticated_data = iv.to_be_bytes().to_vec(); @@ -132,48 +164,28 @@ impl Session { /// Generates session keys from an authentication header. If the IP of the ENR does not match the /// source IP address, we consider this session untrusted. The output returns a boolean which /// specifies if the Session is trusted or not. + #[allow(clippy::too_many_arguments)] pub(crate) fn establish_from_challenge( local_key: Arc>, local_id: &NodeId, remote_id: &NodeId, - challenge: Challenge, + challenge_data: ChallengeData, id_nonce_sig: &[u8], ephem_pubkey: &[u8], - enr_record: Option, + session_enr: Enr, ) -> Result<(Session, Enr), Discv5Error> { - // check and verify a potential ENR update - - // Duplicate code here to avoid cloning an ENR - let remote_public_key = { - let enr = match (enr_record.as_ref(), challenge.remote_enr.as_ref()) { - (Some(new_enr), Some(known_enr)) => { - if new_enr.seq() > known_enr.seq() { - new_enr - } else { - known_enr - } - } - (Some(new_enr), None) => new_enr, - (None, Some(known_enr)) => known_enr, - (None, None) => { - warn!( - "Peer did not respond with their ENR. Session could not be established. Node: {}", - remote_id - ); - return Err(Discv5Error::SessionNotEstablished); - } - }; - enr.public_key() - }; - // verify the auth header nonce if !crypto::verify_authentication_nonce( - &remote_public_key, + &session_enr.public_key(), ephem_pubkey, - &challenge.data, + &challenge_data, local_id, id_nonce_sig, ) { + let challenge = Challenge { + data: challenge_data, + remote_enr: Some(session_enr), + }; return Err(Discv5Error::InvalidChallengeSignature(challenge)); } @@ -185,7 +197,7 @@ impl Session { &local_key.read(), local_id, remote_id, - &challenge.data, + &challenge_data, ephem_pubkey, )?; @@ -194,21 +206,6 @@ impl Session { decryption_key, }; - // Takes ownership of the provided ENRs - Slightly annoying code duplication, but avoids - // cloning ENRs - let session_enr = match (enr_record, challenge.remote_enr) { - (Some(new_enr), Some(known_enr)) => { - if new_enr.seq() > known_enr.seq() { - new_enr - } else { - known_enr - } - } - (Some(new_enr), None) => new_enr, - (None, Some(known_enr)) => known_enr, - (None, None) => unreachable!("Checked in the first match above"), - }; - Ok((Session::new(keys), session_enr)) } diff --git a/src/handler/tests.rs b/src/handler/tests.rs index db9c7d3a8..c2358c394 100644 --- a/src/handler/tests.rs +++ b/src/handler/tests.rs @@ -2,21 +2,19 @@ use super::*; use crate::{ - packet::DefaultProtocolId, + handler::session::build_dummy_session, + packet::{DefaultProtocolId, PacketHeader, MAX_PACKET_SIZE, MESSAGE_NONCE_LENGTH}, return_if_ipv6_is_not_supported, rpc::{Request, Response}, Discv5ConfigBuilder, IpMode, }; use std::net::{Ipv4Addr, Ipv6Addr}; -use crate::{ - handler::{session::build_dummy_session, HandlerOut::RequestFailed}, - RequestError::SelfRequest, -}; +use crate::{handler::HandlerOut::RequestFailed, RequestError::SelfRequest}; use active_requests::ActiveRequests; use enr::EnrBuilder; use std::time::Duration; -use tokio::time::sleep; +use tokio::{net::UdpSocket, time::sleep}; fn init() { let _ = tracing_subscriber::fmt() @@ -24,16 +22,31 @@ fn init() { .try_init(); } -async fn build_handler() -> Handler { - let config = Discv5ConfigBuilder::new(ListenConfig::default()).build(); +struct MockService { + tx: mpsc::UnboundedSender, + rx: mpsc::Receiver, + exit_tx: oneshot::Sender<()>, +} + +async fn build_handler() -> (Handler, MockService) { + build_handler_with_listen_config::

(ListenConfig::default()).await +} + +async fn build_handler_with_listen_config( + listen_config: ListenConfig, +) -> (Handler, MockService) { + let listen_port = listen_config + .ipv4_port() + .expect("listen config should default to ipv4"); + let config = Discv5ConfigBuilder::new(listen_config).build(); let key = CombinedKey::generate_secp256k1(); let enr = EnrBuilder::new("v4") .ip4(Ipv4Addr::LOCALHOST) - .udp4(9000) + .udp4(listen_port) .build(&key) .unwrap(); let mut listen_sockets = SmallVec::default(); - listen_sockets.push((Ipv4Addr::LOCALHOST, 9000).into()); + listen_sockets.push((Ipv4Addr::LOCALHOST, listen_port).into()); let node_id = enr.node_id(); let filter_expected_responses = Arc::new(RwLock::new(HashMap::new())); @@ -58,30 +71,51 @@ async fn build_handler() -> Handler { Socket::new::

(socket_config).await.unwrap() }; - let (_, service_recv) = mpsc::unbounded_channel(); - let (service_send, _) = mpsc::channel(50); - let (_, exit) = oneshot::channel(); - - Handler { - request_retries: config.request_retries, - node_id, - enr: Arc::new(RwLock::new(enr)), - key: Arc::new(RwLock::new(key)), - active_requests: ActiveRequests::new(config.request_timeout), - pending_requests: HashMap::new(), - filter_expected_responses, - sessions: LruTimeCache::new(config.session_timeout, Some(config.session_cache_capacity)), - one_time_sessions: LruTimeCache::new( - Duration::from_secs(ONE_TIME_SESSION_TIMEOUT), - Some(ONE_TIME_SESSION_CACHE_CAPACITY), - ), - active_challenges: HashMapDelay::new(config.request_timeout), - service_recv, - service_send, - listen_sockets, - socket, - exit, - } + let (handler_sender, service_recv) = mpsc::unbounded_channel(); + let (service_send, handler_recv) = mpsc::channel(50); + let (exit_tx, exit) = oneshot::channel(); + + let nat = Nat::new( + &listen_sockets, + &enr, + config.listen_config.ip_mode(), + None, + None, + config.session_cache_capacity, + None, + ); + + ( + Handler { + request_retries: config.request_retries, + node_id, + enr: Arc::new(RwLock::new(enr)), + key: Arc::new(RwLock::new(key)), + active_requests: ActiveRequests::new(config.request_timeout), + pending_requests: HashMap::new(), + filter_expected_responses, + sessions: LruTimeCache::new( + config.session_timeout, + Some(config.session_cache_capacity), + ), + one_time_sessions: LruTimeCache::new( + Duration::from_secs(ONE_TIME_SESSION_TIMEOUT), + Some(ONE_TIME_SESSION_CACHE_CAPACITY), + ), + active_challenges: HashMapDelay::new(config.request_timeout), + service_recv, + service_send, + listen_sockets, + socket, + nat, + exit, + }, + MockService { + tx: handler_sender, + rx: handler_recv, + exit_tx, + }, + ) } macro_rules! arc_rw { @@ -157,9 +191,11 @@ async fn simple_session_message() { loop { if let Some(message) = receiver_recv.recv().await { match message { - HandlerOut::WhoAreYou(wru_ref) => { - let _ = - recv_send.send(HandlerIn::WhoAreYou(wru_ref, Some(sender_enr.clone()))); + HandlerOut::RequestEnr(EnrRequestData::WhoAreYou(wru_ref)) => { + let _ = recv_send.send(HandlerIn::EnrResponse( + Some(sender_enr.clone()), + EnrRequestData::WhoAreYou(wru_ref), + )); } HandlerOut::Request(_, request) => { assert_eq!(request, send_message); @@ -273,8 +309,11 @@ async fn multiple_messages() { let receiver = async move { loop { match receiver_handler.recv().await { - Some(HandlerOut::WhoAreYou(wru_ref)) => { - let _ = recv_send.send(HandlerIn::WhoAreYou(wru_ref, Some(sender_enr.clone()))); + Some(HandlerOut::RequestEnr(EnrRequestData::WhoAreYou(wru_ref))) => { + let _ = recv_send.send(HandlerIn::EnrResponse( + Some(sender_enr.clone()), + EnrRequestData::WhoAreYou(wru_ref), + )); } Some(HandlerOut::Request(addr, request)) => { assert_eq!(request, recv_send_message); @@ -419,7 +458,7 @@ async fn test_self_request_ipv6() { #[tokio::test] async fn remove_one_time_session() { - let mut handler = build_handler::().await; + let (mut handler, _) = build_handler::().await; let enr = { let key = CombinedKey::generate_secp256k1(); @@ -453,3 +492,255 @@ async fn remove_one_time_session() { .is_some()); assert_eq!(0, handler.one_time_sessions.len()); } + +#[tokio::test(flavor = "multi_thread")] +async fn nat_hole_punch_relay() { + init(); + + // Relay + let listen_config = ListenConfig::default().with_ipv4(Ipv4Addr::LOCALHOST, 9901); + let (mut handler, mock_service) = + build_handler_with_listen_config::(listen_config).await; + let relay_addr = handler.enr.read().udp4_socket().unwrap().into(); + let relay_node_id = handler.enr.read().node_id(); + + // Initiator + let inr_enr = { + let key = CombinedKey::generate_secp256k1(); + EnrBuilder::new("v4") + .ip4(Ipv4Addr::LOCALHOST) + .udp4(9011) + .build(&key) + .unwrap() + }; + let inr_addr = inr_enr.udp4_socket().unwrap().into(); + let inr_node_id = inr_enr.node_id(); + + let initr_node_address = NodeAddress::new(inr_addr, inr_enr.node_id()); + handler + .sessions + .insert(initr_node_address, build_dummy_session()); + + let inr_socket = UdpSocket::bind(inr_addr) + .await + .expect("should bind to initiator socket"); + + // Target + let tgt_enr = { + let key = CombinedKey::generate_secp256k1(); + EnrBuilder::new("v4") + .ip4(Ipv4Addr::LOCALHOST) + .udp4(9012) + .build(&key) + .unwrap() + }; + let tgt_addr = tgt_enr.udp4_socket().unwrap().into(); + let tgt_node_id = tgt_enr.node_id(); + + let tgt_node_address = NodeAddress::new(tgt_addr, tgt_enr.node_id()); + handler + .sessions + .insert(tgt_node_address, build_dummy_session()); + + let tgt_socket = UdpSocket::bind(tgt_addr) + .await + .expect("should bind to target socket"); + + // Relay handle + let relay_handle = tokio::spawn(async move { handler.start::().await }); + + // Relay mock service + let tgt_enr_clone = tgt_enr.clone(); + let tx = mock_service.tx; + let mut rx = mock_service.rx; + let mock_service_handle = tokio::spawn(async move { + let service_msg = rx.recv().await.expect("should receive service message"); + match service_msg { + HandlerOut::RequestEnr(EnrRequestData::Nat(relay_init)) => tx + .send(HandlerIn::EnrResponse( + Some(tgt_enr_clone), + EnrRequestData::Nat(relay_init), + )) + .expect("should send message to handler"), + _ => panic!("service message should be 'find hole punch enr'"), + } + }); + + // Initiator handle + let relay_init_notif = + RelayInitNotification::new(inr_enr.clone(), tgt_node_id, MessageNonce::default()); + + let inr_handle = tokio::spawn(async move { + let mut session = build_dummy_session(); + let packet = session + .encrypt_session_message::(inr_node_id, &relay_init_notif.encode()) + .expect("should encrypt notification"); + let encoded_packet = packet.encode::(&relay_node_id); + + inr_socket + .send_to(&encoded_packet, relay_addr) + .await + .expect("should relay init notification to relay") + }); + + // Target handle + let relay_exit = mock_service.exit_tx; + let tgt_handle = tokio::spawn(async move { + let mut buffer = [0; MAX_PACKET_SIZE]; + let res = tgt_socket + .recv_from(&mut buffer) + .await + .expect("should read bytes from socket"); + + drop(relay_exit); + + (res, buffer) + }); + + // Join all handles + let (inr_res, relay_res, tgt_res, mock_service_res) = + tokio::join!(inr_handle, relay_handle, tgt_handle, mock_service_handle); + + inr_res.unwrap(); + relay_res.unwrap(); + mock_service_res.unwrap(); + + let ((length, src), buffer) = tgt_res.unwrap(); + + assert_eq!(src, relay_addr); + + let (packet, aad) = Packet::decode::(&tgt_enr.node_id(), &buffer[..length]) + .expect("should decode packet"); + let Packet { + header, message, .. + } = packet; + let PacketHeader { + kind, + message_nonce, + .. + } = header; + + assert_eq!( + PacketKind::SessionMessage { + src_id: relay_node_id + }, + kind + ); + + let decrypted_message = build_dummy_session() + .decrypt_message(message_nonce, &message, &aad) + .expect("should decrypt message"); + match Message::decode(&decrypted_message).expect("should decode message") { + Message::RelayMsgNotification(relay_msg) => { + let (enr, _) = relay_msg.into(); + assert_eq!(inr_enr, enr) + } + _ => panic!("message should decode to a relay msg notification"), + } +} + +#[tokio::test(flavor = "multi_thread")] +async fn nat_hole_punch_target() { + init(); + + // Target + let listen_config = ListenConfig::default().with_ipv4(Ipv4Addr::LOCALHOST, 9902); + let (mut handler, mock_service) = + build_handler_with_listen_config::(listen_config).await; + let tgt_addr = handler.enr.read().udp4_socket().unwrap().into(); + let tgt_node_id = handler.enr.read().node_id(); + handler.nat.is_behind_nat = Some(true); + + // Relay + let relay_enr = { + let key = CombinedKey::generate_secp256k1(); + EnrBuilder::new("v4") + .ip4(Ipv4Addr::LOCALHOST) + .udp4(9022) + .build(&key) + .unwrap() + }; + let relay_addr = relay_enr.udp4_socket().unwrap().into(); + let relay_node_id = relay_enr.node_id(); + + let relay_node_address = NodeAddress::new(relay_addr, relay_node_id); + handler + .sessions + .insert(relay_node_address, build_dummy_session()); + + let relay_socket = UdpSocket::bind(relay_addr) + .await + .expect("should bind to target socket"); + + // Initiator + let inr_enr = { + let key = CombinedKey::generate_secp256k1(); + EnrBuilder::new("v4") + .ip4(Ipv4Addr::LOCALHOST) + .udp4(9021) + .build(&key) + .unwrap() + }; + let inr_addr = inr_enr.udp4_socket().unwrap(); + let inr_node_id = inr_enr.node_id(); + let inr_nonce: MessageNonce = [1; MESSAGE_NONCE_LENGTH]; + + let inr_socket = UdpSocket::bind(inr_addr) + .await + .expect("should bind to initiator socket"); + + // Target handle + let tgt_handle = tokio::spawn(async move { handler.start::().await }); + + // Relay handle + let relay_msg_notif = RelayMsgNotification::new(inr_enr.clone(), inr_nonce); + + let relay_handle = tokio::spawn(async move { + let mut session = build_dummy_session(); + let packet = session + .encrypt_session_message::(relay_node_id, &relay_msg_notif.encode()) + .expect("should encrypt notification"); + let encoded_packet = packet.encode::(&tgt_node_id); + + relay_socket + .send_to(&encoded_packet, tgt_addr) + .await + .expect("should relay init notification to relay") + }); + + // Initiator handle + let target_exit = mock_service.exit_tx; + let inr_handle = tokio::spawn(async move { + let mut buffer = [0; MAX_PACKET_SIZE]; + let res = inr_socket + .recv_from(&mut buffer) + .await + .expect("should read bytes from socket"); + + drop(target_exit); + + (res, buffer) + }); + + // Join all handles + let (tgt_res, relay_res, inr_res) = tokio::join!(tgt_handle, relay_handle, inr_handle); + + tgt_res.unwrap(); + relay_res.unwrap(); + + let ((length, src), buffer) = inr_res.unwrap(); + + assert_eq!(src, tgt_addr); + + let (packet, _aad) = Packet::decode::(&inr_node_id, &buffer[..length]) + .expect("should decode packet"); + let Packet { header, .. } = packet; + let PacketHeader { + kind, + message_nonce, + .. + } = header; + + assert!(kind.is_whoareyou()); + assert_eq!(message_nonce, inr_nonce) +} diff --git a/src/ipmode.rs b/src/ipmode.rs index de66fd195..f2dbe48da 100644 --- a/src/ipmode.rs +++ b/src/ipmode.rs @@ -69,6 +69,17 @@ impl IpMode { } } +/// Copied from the standard library. See +/// The current code is behind the `ip` feature. +pub const fn to_ipv4_mapped(ip: &std::net::Ipv6Addr) -> Option { + match ip.octets() { + [0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0xff, 0xff, a, b, c, d] => { + Some(std::net::Ipv4Addr::new(a, b, c, d)) + } + _ => None, + } +} + #[cfg(test)] mod tests { use super::*; @@ -230,14 +241,3 @@ mod tests { .test(); } } - -/// Copied from the standard library. See -/// The current code is behind the `ip` feature. -pub const fn to_ipv4_mapped(ip: &std::net::Ipv6Addr) -> Option { - match ip.octets() { - [0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0xff, 0xff, a, b, c, d] => { - Some(std::net::Ipv4Addr::new(a, b, c, d)) - } - _ => None, - } -} diff --git a/src/kbucket/entry.rs b/src/kbucket/entry.rs index 1e0304048..101d637c2 100644 --- a/src/kbucket/entry.rs +++ b/src/kbucket/entry.rs @@ -25,9 +25,7 @@ //! representing the nodes participating in the Kademlia DHT. pub use super::{ - bucket::{ - AppliedPending, ConnectionState, InsertResult, Node, NodeStatus, MAX_NODES_PER_BUCKET, - }, + bucket::{AppliedPending, ConnectionState, InsertResult, Node, NodeStatus}, key::*, ConnectionDirection, }; @@ -176,8 +174,8 @@ where PendingEntry(EntryRef { bucket, key }) } - /// Returns the value associated with the key. - pub fn value(&mut self) -> &mut TVal { + /// Returns mutable access value associated with the key. + pub fn value_mut(&mut self) -> &mut TVal { self.0 .bucket .pending_mut() diff --git a/src/lru_time_cache.rs b/src/lru_time_cache.rs index fae149e3a..b90336ab2 100644 --- a/src/lru_time_cache.rs +++ b/src/lru_time_cache.rs @@ -5,11 +5,16 @@ use std::{ }; pub struct LruTimeCache { - map: LinkedHashMap, + /// The main map storing the internal values. It stores the time the value was inserted and an + /// optional tag to keep track of individual values. + map: LinkedHashMap, /// The time elements remain in the cache. ttl: Duration, /// The max size of the cache. capacity: usize, + /// Optional count of specific tagged elements. This is used in discv5 for tracking + /// the number of unreachable sessions currently held. + tagged_count: usize, } impl LruTimeCache { @@ -23,22 +28,54 @@ impl LruTimeCache { map: LinkedHashMap::new(), ttl, capacity, + tagged_count: 0, } } - /// Inserts a key-value pair into the cache. + /// Returns the number of elements that are currently tagged in the cache. + pub fn tagged(&self) -> usize { + self.tagged_count + } + + // Insert an untagged key-value pair into the cache. pub fn insert(&mut self, key: K, value: V) { + self.insert_raw(key, value, false); + } + + // Insert a tagged key-value pair into the cache. + #[cfg(test)] + pub fn insert_tagged(&mut self, key: K, value: V) { + self.insert_raw(key, value, true); + } + + /// Inserts a key-value pair into the cache. + pub fn insert_raw(&mut self, key: K, value: V, tagged: bool) { let now = Instant::now(); - self.map.insert(key, (value, now)); + if let Some(old_value) = self.map.insert(key, (value, now, tagged)) { + // If the old value was tagged but the new one isn't, we reduce our count + if !tagged && old_value.2 { + self.tagged_count = self.tagged_count.saturating_sub(1); + } else if tagged && !old_value.2 { + // Else if the new value is tagged and the old wasn't tagged increment the count + self.tagged_count += 1; + } + } else if tagged { + // No previous value, increment the tagged count + self.tagged_count += 1; + } if self.map.len() > self.capacity { - self.map.pop_front(); + if let Some((_, value)) = self.map.pop_front() { + if value.2 { + // We have removed a tagged element + self.tagged_count = self.tagged_count.saturating_sub(1); + } + } } } /// Retrieves a reference to the value stored under `key`, or `None` if the key doesn't exist. /// Also removes expired elements and updates the time. - #[allow(dead_code)] pub fn get(&mut self, key: &K) -> Option<&V> { self.get_mut(key).map(|value| &*value) } @@ -61,16 +98,14 @@ impl LruTimeCache { /// Returns a reference to the value with the given `key`, if present and not expired, without /// updating the timestamp. - #[allow(dead_code)] pub fn peek(&self, key: &K) -> Option<&V> { - if let Some((value, time)) = self.map.get(key) { + if let Some((value, time, _)) = self.map.get(key) { return if *time + self.ttl >= Instant::now() { Some(value) } else { None }; } - None } @@ -83,14 +118,20 @@ impl LruTimeCache { /// Removes a key-value pair from the cache, returning the value at the key if the key /// was previously in the map. pub fn remove(&mut self, key: &K) -> Option { - self.map.remove(key).map(|v| v.0) + let value = self.map.remove(key)?; + + // This element was tagged, reduce the count + if value.2 { + self.tagged_count = self.tagged_count.saturating_sub(1); + } + Some(value.0) } /// Removes expired items from the cache. fn remove_expired_values(&mut self, now: Instant) { let mut expired_keys = vec![]; - for (key, (_, time)) in self.map.iter_mut() { + for (key, (_, time, _)) in self.map.iter_mut() { if *time + self.ttl >= now { break; } @@ -98,9 +139,18 @@ impl LruTimeCache { } for k in expired_keys { - self.map.remove(&k); + if let Some(v) = self.map.remove(&k) { + // This key was tagged, reduce the count + if v.2 { + self.tagged_count = self.tagged_count.saturating_sub(1); + } + } } } + + pub fn clear(&mut self) { + self.map.clear() + } } #[cfg(test)] @@ -135,6 +185,30 @@ mod tests { assert_eq!(Some(&30), cache.get(&3)); } + #[test] + fn tagging() { + let mut cache = LruTimeCache::new(Duration::from_secs(10), Some(2)); + + cache.insert_tagged(1, 10); + cache.insert(2, 20); + assert_eq!(2, cache.len()); + assert_eq!(1, cache.tagged()); + + cache.insert_tagged(3, 30); + assert_eq!(2, cache.len()); + assert_eq!(1, cache.tagged()); + assert_eq!(Some(&20), cache.get(&2)); + assert_eq!(Some(&30), cache.get(&3)); + + cache.insert_tagged(2, 30); + assert_eq!(2, cache.tagged()); + + cache.insert(4, 30); + assert_eq!(1, cache.tagged()); + cache.insert(5, 30); + assert_eq!(0, cache.tagged()); + } + #[test] fn get() { let mut cache = LruTimeCache::new(Duration::from_secs(10), Some(2)); diff --git a/src/node_info.rs b/src/node_info.rs index 6224fe12b..980e6b0ec 100644 --- a/src/node_info.rs +++ b/src/node_info.rs @@ -1,5 +1,6 @@ use super::*; use crate::Enr; +use derive_more::Display; use enr::{CombinedPublicKey, NodeId}; use std::net::SocketAddr; @@ -10,7 +11,7 @@ use libp2p_identity::{KeyType, PublicKey}; /// This type relaxes the requirement of having an ENR to connect to a node, to allow for unsigned /// connection types, such as multiaddrs. -#[derive(Debug, Clone, PartialEq)] +#[derive(Debug, Clone, PartialEq, Eq)] pub struct NodeContact { /// Key to use for communications with this node. public_key: CombinedPublicKey, @@ -148,7 +149,8 @@ impl std::fmt::Display for NodeContact { } /// A representation of an unsigned contactable node. -#[derive(PartialEq, Hash, Eq, Clone, Debug)] +#[derive(PartialEq, Hash, Eq, Clone, Debug, Display)] +#[display(fmt = "Node: {node_id}, addr: {socket_addr}")] pub struct NodeAddress { /// The destination socket address. pub socket_addr: SocketAddr, @@ -184,9 +186,3 @@ impl NodeAddress { } } } - -impl std::fmt::Display for NodeAddress { - fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result { - write!(f, "Node: {}, addr: {:?}", self.node_id, self.socket_addr) - } -} diff --git a/src/packet/mod.rs b/src/packet/mod.rs index f071263bd..c59e72e08 100644 --- a/src/packet/mod.rs +++ b/src/packet/mod.rs @@ -24,7 +24,7 @@ pub const IV_LENGTH: usize = 16; /// The length of the static header. (6 byte protocol id, 2 bytes version, 1 byte kind, 12 byte /// message nonce and a 2 byte authdata-size). pub const STATIC_HEADER_LENGTH: usize = 23; -/// The message nonce length (in bytes). +/// The message nonce length (in bytes). This must be at least 4 bytes. pub const MESSAGE_NONCE_LENGTH: usize = 12; /// The Id nonce length (in bytes). pub const ID_NONCE_LENGTH: usize = 16; @@ -36,7 +36,7 @@ impl ProtocolIdentity for DefaultProtocolId { const PROTOCOL_VERSION_BYTES: [u8; 2] = 0x0001_u16.to_be_bytes(); } -pub trait ProtocolIdentity { +pub trait ProtocolIdentity: Sync + Send { const PROTOCOL_ID_BYTES: [u8; 6]; const PROTOCOL_VERSION_BYTES: [u8; 2]; } @@ -141,6 +141,14 @@ pub enum PacketKind { /// The ENR record of the node if the WHOAREYOU request is out-dated. enr_record: Option, }, + /// A session message is a notification, hence it differs from the [`PacketKind::Message`] in + /// the way it handles sessions since notifications don't trigger responses, a session + /// message packet doesn't trigger a WHOAREYOU response. If a session doesn't exist to + /// decrypt or encrypt a notification, it is dropped. + SessionMessage { + /// The sending NodeId. + src_id: NodeId, + }, } impl From<&PacketKind> for u8 { @@ -149,6 +157,7 @@ impl From<&PacketKind> for u8 { PacketKind::Message { .. } => 0, PacketKind::WhoAreYou { .. } => 1, PacketKind::Handshake { .. } => 2, + PacketKind::SessionMessage { .. } => 3, } } } @@ -157,7 +166,9 @@ impl PacketKind { /// Encodes the packet type into its corresponding auth_data. pub fn encode(&self) -> Vec { match self { - PacketKind::Message { src_id } => src_id.raw().to_vec(), + PacketKind::Message { src_id } | PacketKind::SessionMessage { src_id } => { + src_id.raw().to_vec() + } PacketKind::WhoAreYou { id_nonce, enr_seq } => { let mut auth_data = Vec::with_capacity(24); auth_data.extend_from_slice(id_nonce); @@ -273,6 +284,16 @@ impl PacketKind { enr_record, }) } + 3 => { + // Decoding a SessionMessage packet + // This should only contain a 32 byte NodeId. + if auth_data.len() != 32 { + return Err(PacketError::InvalidAuthDataSize); + } + + let src_id = NodeId::parse(auth_data).map_err(|_| PacketError::InvalidNodeId)?; + Ok(PacketKind::SessionMessage { src_id }) + } _ => Err(PacketError::UnknownPacket), } } @@ -362,7 +383,9 @@ impl Packet { pub fn is_whoareyou(&self) -> bool { match &self.header.kind { PacketKind::WhoAreYou { .. } => true, - PacketKind::Message { .. } | PacketKind::Handshake { .. } => false, + PacketKind::Message { .. } + | PacketKind::Handshake { .. } + | PacketKind::SessionMessage { .. } => false, } } @@ -370,7 +393,7 @@ impl Packet { /// src_id in this case. pub fn src_id(&self) -> Option { match self.header.kind { - PacketKind::Message { src_id } => Some(src_id), + PacketKind::Message { src_id } | PacketKind::SessionMessage { src_id } => Some(src_id), PacketKind::WhoAreYou { .. } => None, PacketKind::Handshake { src_id, .. } => Some(src_id), } @@ -421,7 +444,7 @@ impl Packet { /// /// This also returns the authenticated data for further decryption in the handler. pub fn decode( - src_id: &NodeId, + dst_id: &NodeId, data: &[u8], ) -> Result<(Self, Vec), PacketError> { if data.len() > MAX_PACKET_SIZE { @@ -439,7 +462,7 @@ impl Packet { * This was split into its own library, but brought back to allow re-use of the cipher when * performing the decryption */ - let key = GenericArray::clone_from_slice(&src_id.raw()[..16]); + let key = GenericArray::clone_from_slice(&dst_id.raw()[..16]); let nonce = GenericArray::clone_from_slice(&iv); let mut cipher = Aes128Ctr::new(&key, &nonce); @@ -565,6 +588,9 @@ impl std::fmt::Display for PacketKind { hex::encode(ephem_pubkey), enr_record ), + PacketKind::SessionMessage { src_id } => { + write!(f, "SessionMessage {{ src_id: {src_id} }}") + } } } } diff --git a/src/rpc.rs b/src/rpc.rs index c5212c99c..974cb8d4b 100644 --- a/src/rpc.rs +++ b/src/rpc.rs @@ -1,303 +1,88 @@ -use enr::{CombinedKey, Enr}; -use rlp::{DecoderError, RlpStream}; -use std::net::{IpAddr, Ipv6Addr}; -use tracing::{debug, warn}; - -/// Type to manage the request IDs. -#[derive(Debug, Clone, PartialEq, Hash, Eq)] -pub struct RequestId(pub Vec); - -impl From for Vec { - fn from(id: RequestId) -> Self { - id.0 - } +use derive_more::{Display, From}; +use rlp::{DecoderError, Rlp}; +use std::convert::{TryFrom, TryInto}; + +mod notification; +mod request; +mod response; + +pub use notification::{RelayInitNotification, RelayMsgNotification}; +pub use request::{Request, RequestBody, RequestId}; +pub use response::{Response, ResponseBody}; + +/// Message type IDs. +#[derive(Debug)] +#[repr(u8)] +pub enum MessageType { + Ping = 1, + Pong = 2, + FindNode = 3, + Nodes = 4, + TalkReq = 5, + TalkResp = 6, + RelayInit = 7, + RelayMsg = 8, } -impl RequestId { - /// Decodes the ID from a raw bytes. - pub fn decode(data: Vec) -> Result { - if data.len() > 8 { - return Err(DecoderError::Custom("Invalid ID length")); +impl TryFrom for MessageType { + type Error = DecoderError; + fn try_from(byte: u8) -> Result { + match byte { + 1 => Ok(MessageType::Ping), + 2 => Ok(MessageType::Pong), + 3 => Ok(MessageType::FindNode), + 4 => Ok(MessageType::Nodes), + 5 => Ok(MessageType::TalkReq), + 6 => Ok(MessageType::TalkResp), + 7 => Ok(MessageType::RelayInit), + 8 => Ok(MessageType::RelayMsg), + _ => Err(DecoderError::Custom("Unknown RPC message type")), } - Ok(RequestId(data)) - } - - pub fn random() -> Self { - let rand: u64 = rand::random(); - RequestId(rand.to_be_bytes().to_vec()) } +} - pub fn as_bytes(&self) -> &[u8] { - &self.0 - } +/// The payload of message containers SessionMessage, Message or Handshake type. +pub trait Payload +where + Self: Sized, +{ + /// Matches a payload type to its message type id. + fn msg_type(&self) -> u8; + /// Encodes a message to RLP-encoded bytes. + fn encode(self) -> Vec; + /// Decodes RLP-encoded bytes into a message. + fn decode(msg_type: u8, rlp: &Rlp<'_>) -> Result; } -#[derive(Debug, Clone, PartialEq, Eq)] -/// A combined type representing requests and responses. +#[derive(Debug, Clone, PartialEq, Eq, Display, From)] +/// A combined type representing the messages which are the payloads of packets. pub enum Message { /// A request, which contains its [`RequestId`]. + #[display(fmt = "{_0}")] Request(Request), + /// A Response, which contains the [`RequestId`] of its associated request. + #[display(fmt = "{_0}")] Response(Response), -} - -#[derive(Debug, Clone, PartialEq, Eq)] -/// A request sent between nodes. -pub struct Request { - /// The [`RequestId`] of the request. - pub id: RequestId, - /// The body of the request. - pub body: RequestBody, -} - -#[derive(Debug, Clone, PartialEq, Eq)] -/// A response sent in response to a [`Request`] -pub struct Response { - /// The [`RequestId`] of the request that triggered this response. - pub id: RequestId, - /// The body of this response. - pub body: ResponseBody, -} - -#[derive(Debug, Clone, PartialEq, Eq)] -pub enum RequestBody { - /// A PING request. - Ping { - /// Our current ENR sequence number. - enr_seq: u64, - }, - /// A FINDNODE request. - FindNode { - /// The distance(s) of peers we expect to be returned in the response. - distances: Vec, - }, - /// A Talk request. - Talk { - /// The protocol requesting. - protocol: Vec, - /// The request. - request: Vec, - }, -} - -#[derive(Debug, Clone, PartialEq, Eq)] -pub enum ResponseBody { - /// A PONG response. - Pong { - /// The current ENR sequence number of the responder. - enr_seq: u64, - /// Our external IP address as observed by the responder. - ip: IpAddr, - /// Our external UDP port as observed by the responder. - port: u16, - }, - /// A NODES response. - Nodes { - /// The total number of responses that make up this response. - total: u64, - /// A list of ENR's returned by the responder. - nodes: Vec>, - }, - /// The TALK response. - Talk { - /// The response for the talk. - response: Vec, - }, -} - -impl Request { - pub fn msg_type(&self) -> u8 { - match self.body { - RequestBody::Ping { .. } => 1, - RequestBody::FindNode { .. } => 3, - RequestBody::Talk { .. } => 5, - } - } - - /// Encodes a Message to RLP-encoded bytes. - pub fn encode(self) -> Vec { - let mut buf = Vec::with_capacity(10); - let msg_type = self.msg_type(); - buf.push(msg_type); - let id = &self.id; - match self.body { - RequestBody::Ping { enr_seq } => { - let mut s = RlpStream::new(); - s.begin_list(2); - s.append(&id.as_bytes()); - s.append(&enr_seq); - buf.extend_from_slice(&s.out()); - buf - } - RequestBody::FindNode { distances } => { - let mut s = RlpStream::new(); - s.begin_list(2); - s.append(&id.as_bytes()); - s.begin_list(distances.len()); - for distance in distances { - s.append(&distance); - } - buf.extend_from_slice(&s.out()); - buf - } - RequestBody::Talk { protocol, request } => { - let mut s = RlpStream::new(); - s.begin_list(3); - s.append(&id.as_bytes()); - s.append(&protocol); - s.append(&request); - buf.extend_from_slice(&s.out()); - buf - } - } - } -} - -impl Response { - pub fn msg_type(&self) -> u8 { - match &self.body { - ResponseBody::Pong { .. } => 2, - ResponseBody::Nodes { .. } => 4, - ResponseBody::Talk { .. } => 6, - } - } - - /// Determines if the response is a valid response to the given request. - pub fn match_request(&self, req: &RequestBody) -> bool { - match self.body { - ResponseBody::Pong { .. } => matches!(req, RequestBody::Ping { .. }), - ResponseBody::Nodes { .. } => { - matches!(req, RequestBody::FindNode { .. }) - } - ResponseBody::Talk { .. } => matches!(req, RequestBody::Talk { .. }), - } - } - - /// Encodes a Message to RLP-encoded bytes. - pub fn encode(self) -> Vec { - let mut buf = Vec::with_capacity(10); - let msg_type = self.msg_type(); - buf.push(msg_type); - let id = &self.id; - match self.body { - ResponseBody::Pong { enr_seq, ip, port } => { - let mut s = RlpStream::new(); - s.begin_list(4); - s.append(&id.as_bytes()); - s.append(&enr_seq); - match ip { - IpAddr::V4(addr) => s.append(&(&addr.octets() as &[u8])), - IpAddr::V6(addr) => s.append(&(&addr.octets() as &[u8])), - }; - s.append(&port); - buf.extend_from_slice(&s.out()); - buf - } - ResponseBody::Nodes { total, nodes } => { - let mut s = RlpStream::new(); - s.begin_list(3); - s.append(&id.as_bytes()); - s.append(&total); - - if nodes.is_empty() { - s.begin_list(0); - } else { - s.begin_list(nodes.len()); - for node in nodes { - s.append(&node); - } - } - buf.extend_from_slice(&s.out()); - buf - } - ResponseBody::Talk { response } => { - let mut s = RlpStream::new(); - s.begin_list(2); - s.append(&id.as_bytes()); - s.append(&response); - buf.extend_from_slice(&s.out()); - buf - } - } - } -} - -impl std::fmt::Display for RequestId { - fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result { - write!(f, "{}", hex::encode(&self.0)) - } -} - -impl std::fmt::Display for Message { - fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result { - match self { - Message::Request(request) => write!(f, "{request}"), - Message::Response(response) => write!(f, "{response}"), - } - } -} - -impl std::fmt::Display for Response { - fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result { - write!(f, "Response: id: {}: {}", self.id, self.body) - } -} - -impl std::fmt::Display for ResponseBody { - fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result { - match self { - ResponseBody::Pong { enr_seq, ip, port } => { - write!(f, "PONG: Enr-seq: {enr_seq}, Ip: {ip:?}, Port: {port}") - } - ResponseBody::Nodes { total, nodes } => { - write!(f, "NODES: total: {total}, Nodes: [")?; - let mut first = true; - for id in nodes { - if !first { - write!(f, ", {id}")?; - } else { - write!(f, "{id}")?; - } - first = false; - } - - write!(f, "]") - } - ResponseBody::Talk { response } => { - write!(f, "Response: Response {}", hex::encode(response)) - } - } - } -} -impl std::fmt::Display for Request { - fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result { - write!(f, "Request: id: {}: {}", self.id, self.body) - } + /// Unicast notifications. + /// + /// A [`RelayInitNotification`]. + #[display(fmt = "{_0}")] + RelayInitNotification(RelayInitNotification), + /// A [`RelayMsgNotification`]. + #[display(fmt = "{_0}")] + RelayMsgNotification(RelayMsgNotification), } -impl std::fmt::Display for RequestBody { - fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result { - match self { - RequestBody::Ping { enr_seq } => write!(f, "PING: enr_seq: {enr_seq}"), - RequestBody::FindNode { distances } => { - write!(f, "FINDNODE Request: distance: {distances:?}") - } - RequestBody::Talk { protocol, request } => write!( - f, - "TALK: protocol: {}, request: {}", - hex::encode(protocol), - hex::encode(request) - ), - } - } -} #[allow(dead_code)] impl Message { pub fn encode(self) -> Vec { match self { Self::Request(request) => request.encode(), Self::Response(response) => response.encode(), + Self::RelayInitNotification(notif) => notif.encode(), + Self::RelayMsgNotification(notif) => notif.encode(), } } @@ -305,183 +90,38 @@ impl Message { if data.len() < 3 { return Err(DecoderError::RlpIsTooShort); } - let msg_type = data[0]; let rlp = rlp::Rlp::new(&data[1..]); - let list_len = rlp.item_count().and_then(|size| { - if size < 2 { - Err(DecoderError::RlpIncorrectListLen) - } else { - Ok(size) - } - })?; - - let id = RequestId::decode(rlp.val_at::>(0)?)?; - - let message = match msg_type { - 1 => { - // PingRequest - if list_len != 2 { - debug!( - "Ping Request has an invalid RLP list length. Expected 2, found {}", - list_len - ); - return Err(DecoderError::RlpIncorrectListLen); - } - Message::Request(Request { - id, - body: RequestBody::Ping { - enr_seq: rlp.val_at::(1)?, - }, - }) - } - 2 => { - // PingResponse - if list_len != 4 { - debug!( - "Ping Response has an invalid RLP list length. Expected 4, found {}", - list_len - ); - return Err(DecoderError::RlpIncorrectListLen); - } - let ip_bytes = rlp.val_at::>(2)?; - let ip = match ip_bytes.len() { - 4 => { - let mut ip = [0u8; 4]; - ip.copy_from_slice(&ip_bytes); - IpAddr::from(ip) - } - 16 => { - let mut ip = [0u8; 16]; - ip.copy_from_slice(&ip_bytes); - let ipv6 = Ipv6Addr::from(ip); - - if ipv6.is_loopback() { - // Checking if loopback address since IPv6Addr::to_ipv4 returns - // IPv4 address for IPv6 loopback address. - IpAddr::V6(ipv6) - } else if let Some(ipv4) = ipv6.to_ipv4() { - // If the ipv6 is ipv4 compatible/mapped, simply return the ipv4. - IpAddr::V4(ipv4) - } else { - IpAddr::V6(ipv6) - } - } - _ => { - debug!("Ping Response has incorrect byte length for IP"); - return Err(DecoderError::RlpIncorrectListLen); - } - }; - let port = rlp.val_at::(3)?; - Message::Response(Response { - id, - body: ResponseBody::Pong { - enr_seq: rlp.val_at::(1)?, - ip, - port, - }, - }) - } - 3 => { - // FindNodeRequest - if list_len != 2 { - debug!( - "FindNode Request has an invalid RLP list length. Expected 2, found {}", - list_len - ); - return Err(DecoderError::RlpIncorrectListLen); - } - let distances = rlp.list_at::(1)?; - - for distance in distances.iter() { - if distance > &256u64 { - warn!( - "Rejected FindNode request asking for unknown distance {}, maximum 256", - distance - ); - return Err(DecoderError::Custom("FINDNODE request distance invalid")); - } - } - - Message::Request(Request { - id, - body: RequestBody::FindNode { distances }, - }) + match msg_type.try_into()? { + MessageType::Ping | MessageType::FindNode | MessageType::TalkReq => { + Ok(Request::decode(msg_type, &rlp)?.into()) } - 4 => { - // NodesResponse - if list_len != 3 { - debug!( - "Nodes Response has an invalid RLP list length. Expected 3, found {}", - list_len - ); - return Err(DecoderError::RlpIncorrectListLen); - } - - let nodes = { - let enr_list_rlp = rlp.at(2)?; - if enr_list_rlp.is_empty() { - // no records - vec![] - } else { - enr_list_rlp.as_list::>()? - } - }; - Message::Response(Response { - id, - body: ResponseBody::Nodes { - total: rlp.val_at::(1)?, - nodes, - }, - }) + MessageType::Pong | MessageType::Nodes | MessageType::TalkResp => { + Ok(Response::decode(msg_type, &rlp)?.into()) } - 5 => { - // Talk Request - if list_len != 3 { - debug!( - "Talk Request has an invalid RLP list length. Expected 3, found {}", - list_len - ); - return Err(DecoderError::RlpIncorrectListLen); - } - let protocol = rlp.val_at::>(1)?; - let request = rlp.val_at::>(2)?; - Message::Request(Request { - id, - body: RequestBody::Talk { protocol, request }, - }) - } - 6 => { - // Talk Response - if list_len != 2 { - debug!( - "Talk Response has an invalid RLP list length. Expected 2, found {}", - list_len - ); - return Err(DecoderError::RlpIncorrectListLen); - } - let response = rlp.val_at::>(1)?; - Message::Response(Response { - id, - body: ResponseBody::Talk { response }, - }) - } - _ => { - return Err(DecoderError::Custom("Unknown RPC message type")); - } - }; + MessageType::RelayInit => Ok(RelayInitNotification::decode(msg_type, &rlp)?.into()), + MessageType::RelayMsg => Ok(RelayMsgNotification::decode(msg_type, &rlp)?.into()), + } + } - Ok(message) + pub fn msg_type(&self) -> String { + match self { + Self::Request(r) => format!("request type {}", r.msg_type()), + Self::Response(r) => format!("response type {}", r.msg_type()), + Self::RelayInitNotification(n) => format!("notification type {}", n.msg_type()), + Self::RelayMsgNotification(n) => format!("notification type {}", n.msg_type()), + } } } #[cfg(test)] mod tests { use super::*; - use enr::EnrBuilder; - use std::net::Ipv4Addr; + use crate::packet::MESSAGE_NONCE_LENGTH; + use enr::{CombinedKey, Enr, EnrBuilder}; + use std::net::{IpAddr, Ipv4Addr, Ipv6Addr}; #[test] fn ref_test_encode_request_ping() { @@ -746,4 +386,67 @@ mod tests { assert_eq!(request, decoded); } + + #[test] + fn encode_decode_talk_request() { + let id = RequestId(vec![1]); + let request = Message::Request(Request { + id, + body: RequestBody::TalkReq { + protocol: vec![17u8; 32], + request: vec![1, 2, 3], + }, + }); + + let encoded = request.clone().encode(); + let decoded = Message::decode(&encoded).unwrap(); + + assert_eq!(request, decoded); + } + + #[test] + fn test_encode_decode_relay_init() { + // generate a new enr key for the initiator + let enr_key = CombinedKey::generate_secp256k1(); + // construct the initiator's ENR + let inr_enr = EnrBuilder::new("v4").build(&enr_key).unwrap(); + + // generate a new enr key for the target + let enr_key_tgt = CombinedKey::generate_secp256k1(); + // construct the target's ENR + let tgt_enr = EnrBuilder::new("v4").build(&enr_key_tgt).unwrap(); + let tgt_node_id = tgt_enr.node_id(); + + let nonce_bytes = hex::decode("47644922f5d6e951051051ac").unwrap(); + let mut nonce = [0u8; MESSAGE_NONCE_LENGTH]; + nonce[MESSAGE_NONCE_LENGTH - nonce_bytes.len()..].copy_from_slice(&nonce_bytes); + + let notif = RelayInitNotification::new(inr_enr, tgt_node_id, nonce); + let msg = Message::RelayInitNotification(notif); + + let encoded_msg = msg.clone().encode(); + let decoded_msg = Message::decode(&encoded_msg).expect("Should decode"); + + assert_eq!(msg, decoded_msg); + } + + #[test] + fn test_enocde_decode_relay_msg() { + // generate a new enr key for the initiator + let enr_key = CombinedKey::generate_secp256k1(); + // construct the initiator's ENR + let inr_enr = EnrBuilder::new("v4").build(&enr_key).unwrap(); + + let nonce_bytes = hex::decode("9951051051aceb").unwrap(); + let mut nonce = [0u8; MESSAGE_NONCE_LENGTH]; + nonce[MESSAGE_NONCE_LENGTH - nonce_bytes.len()..].copy_from_slice(&nonce_bytes); + + let notif = RelayMsgNotification::new(inr_enr, nonce); + let msg = Message::RelayMsgNotification(notif); + + let encoded_msg = msg.clone().encode(); + let decoded_msg = Message::decode(&encoded_msg).expect("Should decode"); + + assert_eq!(msg, decoded_msg); + } } diff --git a/src/rpc/notification.rs b/src/rpc/notification.rs new file mode 100644 index 000000000..cfae7d6b0 --- /dev/null +++ b/src/rpc/notification.rs @@ -0,0 +1,160 @@ +use super::{MessageType, Payload}; +use crate::{ + packet::{MessageNonce, MESSAGE_NONCE_LENGTH}, + Enr, +}; +use derive_more::Display; +use enr::NodeId; +use rlp::{DecoderError, Rlp, RlpStream}; + +/// Nonce of request that triggered the initiation of this hole punching attempt. +type NonceOfTimedOutMessage = MessageNonce; +/// Node id length in bytes. +pub const NODE_ID_LENGTH: usize = 32; + +/// Unicast notifications [`RelayInitNotification`] and [`RelayMsgNotification`] sent over discv5. + +/// A notification to initialise a one-shot relay circuit for hole-punching. +#[derive(Debug, Display, PartialEq, Eq, Clone)] +#[display(fmt = "Notification: RelayInit: Initiator: {_0}, Target: {_1}, Nonce: {_2:?}")] +pub struct RelayInitNotification(Enr, NodeId, NonceOfTimedOutMessage); + +impl RelayInitNotification { + pub fn new( + initr_enr: Enr, + tgt_node_id: NodeId, + timed_out_msg_nonce: NonceOfTimedOutMessage, + ) -> Self { + Self(initr_enr, tgt_node_id, timed_out_msg_nonce) + } + + pub fn initiator_enr(&self) -> &Enr { + &self.0 + } + + pub fn target_node_id(&self) -> NodeId { + self.1 + } +} + +impl Payload for RelayInitNotification { + /// Matches a notification type to its message type id. + fn msg_type(&self) -> u8 { + MessageType::RelayInit as u8 + } + + /// Encodes a notification message to RLP-encoded bytes. + fn encode(self) -> Vec { + let mut buf = Vec::with_capacity(100); + let msg_type = self.msg_type(); + buf.push(msg_type); + let mut s = RlpStream::new(); + let Self(initiator, target, nonce) = self; + + s.begin_list(3); + s.append(&initiator); + s.append(&(&target.raw() as &[u8])); + s.append(&(&nonce as &[u8])); + + buf.extend_from_slice(&s.out()); + buf + } + + /// Decodes RLP-encoded bytes into a notification message. + fn decode(_msg_type: u8, rlp: &Rlp<'_>) -> Result { + if rlp.item_count()? != 3 { + return Err(DecoderError::RlpIncorrectListLen); + } + let initiator = rlp.val_at::(0)?; + + let tgt_bytes = rlp.val_at::>(1)?; + if tgt_bytes.len() > NODE_ID_LENGTH { + return Err(DecoderError::RlpIsTooBig); + } + let mut tgt = [0u8; NODE_ID_LENGTH]; + tgt[NODE_ID_LENGTH - tgt_bytes.len()..].copy_from_slice(&tgt_bytes); + let tgt = NodeId::from(tgt); + + let nonce = { + let bytes = rlp.val_at::>(2)?; + if bytes.len() > MESSAGE_NONCE_LENGTH { + return Err(DecoderError::RlpIsTooBig); + } + let mut buf = [0u8; MESSAGE_NONCE_LENGTH]; + buf[MESSAGE_NONCE_LENGTH - bytes.len()..].copy_from_slice(&bytes); + buf + }; + + Ok(Self(initiator, tgt, nonce)) + } +} + +impl From for (Enr, NodeId, NonceOfTimedOutMessage) { + fn from(value: RelayInitNotification) -> Self { + let RelayInitNotification(initr_enr, tgt_node_id, timed_out_msg_nonce) = value; + + (initr_enr, tgt_node_id, timed_out_msg_nonce) + } +} + +/// The notification relayed to target of hole punch attempt. +#[derive(Debug, Display, PartialEq, Eq, Clone)] +#[display(fmt = "Notification: RelayMsg: Initiator: {_0}, Nonce: {_1:?}")] +pub struct RelayMsgNotification(Enr, NonceOfTimedOutMessage); + +impl RelayMsgNotification { + pub fn new(initr_enr: Enr, timed_out_msg_nonce: NonceOfTimedOutMessage) -> Self { + RelayMsgNotification(initr_enr, timed_out_msg_nonce) + } +} + +impl Payload for RelayMsgNotification { + /// Matches a notification type to its message type id. + fn msg_type(&self) -> u8 { + MessageType::RelayMsg as u8 + } + + /// Encodes a notification message to RLP-encoded bytes. + fn encode(self) -> Vec { + let mut buf = Vec::with_capacity(100); + let msg_type = self.msg_type(); + buf.push(msg_type); + let mut s = RlpStream::new(); + let Self(initiator, nonce) = self; + + s.begin_list(2); + s.append(&initiator); + s.append(&(&nonce as &[u8])); + + buf.extend_from_slice(&s.out()); + buf + } + + /// Decodes RLP-encoded bytes into a notification message. + fn decode(_msg_type: u8, rlp: &Rlp<'_>) -> Result { + if rlp.item_count()? != 2 { + return Err(DecoderError::RlpIncorrectListLen); + } + let initiator = rlp.val_at::(0)?; + + let nonce = { + let bytes = rlp.val_at::>(1)?; + if bytes.len() > MESSAGE_NONCE_LENGTH { + return Err(DecoderError::RlpIsTooBig); + } + let mut buf = [0u8; MESSAGE_NONCE_LENGTH]; + buf[MESSAGE_NONCE_LENGTH - bytes.len()..].copy_from_slice(&bytes); + buf + }; + + Ok(Self(initiator, nonce)) + } +} + +impl From for (Enr, NonceOfTimedOutMessage) { + fn from(value: RelayMsgNotification) -> Self { + let RelayMsgNotification(initr_enr, timed_out_msg_nonce) = value; + + (initr_enr, timed_out_msg_nonce) + } +} diff --git a/src/rpc/request.rs b/src/rpc/request.rs new file mode 100644 index 000000000..26cd50983 --- /dev/null +++ b/src/rpc/request.rs @@ -0,0 +1,205 @@ +use super::{MessageType, Payload}; +use derive_more::Display; +use rlp::{DecoderError, Rlp, RlpStream}; +use std::convert::TryInto; +use tracing::{debug, warn}; + +/// A request sent between nodes. +#[derive(Debug, Clone, PartialEq, Eq, Display)] +#[display(fmt = "Request: id: {id}: {body}")] +pub struct Request { + /// The [`RequestId`] of the request. + pub id: RequestId, + /// The body of the request. + pub body: RequestBody, +} + +impl Payload for Request { + /// Matches a request type to its message type id. + fn msg_type(&self) -> u8 { + match self.body { + RequestBody::Ping { .. } => MessageType::Ping as u8, + RequestBody::FindNode { .. } => MessageType::FindNode as u8, + RequestBody::TalkReq { .. } => MessageType::TalkReq as u8, + } + } + + /// Encodes a request message to RLP-encoded bytes. + fn encode(self) -> Vec { + let mut buf = Vec::with_capacity(10); + let msg_type = self.msg_type(); + buf.push(msg_type); + let id = &self.id; + match self.body { + RequestBody::Ping { enr_seq } => { + let mut s = RlpStream::new(); + s.begin_list(2); + s.append(&id.as_bytes()); + s.append(&enr_seq); + buf.extend_from_slice(&s.out()); + buf + } + RequestBody::FindNode { distances } => { + let mut s = RlpStream::new(); + s.begin_list(2); + s.append(&id.as_bytes()); + s.begin_list(distances.len()); + for distance in distances { + s.append(&distance); + } + buf.extend_from_slice(&s.out()); + buf + } + RequestBody::TalkReq { protocol, request } => { + let mut s = RlpStream::new(); + s.begin_list(3); + s.append(&id.as_bytes()); + s.append(&protocol); + s.append(&request); + buf.extend_from_slice(&s.out()); + buf + } + } + } + + /// Decodes RLP-encoded bytes into a request message. + fn decode(msg_type: u8, rlp: &Rlp<'_>) -> Result { + let list_len = rlp.item_count()?; + let id = RequestId::decode(rlp.val_at::>(0)?)?; + let message = match msg_type.try_into()? { + MessageType::Ping => { + // Ping Request + if list_len != 2 { + debug!( + "Ping Request has an invalid RLP list length. Expected 2, found {}", + list_len + ); + return Err(DecoderError::RlpIncorrectListLen); + } + Self { + id, + body: RequestBody::Ping { + enr_seq: rlp.val_at::(1)?, + }, + } + } + MessageType::FindNode => { + // FindNode Request + if list_len != 2 { + debug!( + "FindNode Request has an invalid RLP list length. Expected 2, found {}", + list_len + ); + return Err(DecoderError::RlpIncorrectListLen); + } + let distances = rlp.list_at::(1)?; + + for distance in distances.iter() { + if distance > &256u64 { + warn!( + "Rejected FindNode request asking for unknown distance {}, maximum 256", + distance + ); + return Err(DecoderError::Custom("FINDNODE request distance invalid")); + } + } + + Self { + id, + body: RequestBody::FindNode { distances }, + } + } + MessageType::TalkReq => { + // Talk Request + if list_len != 3 { + debug!( + "Talk Request has an invalid RLP list length. Expected 3, found {}", + list_len + ); + return Err(DecoderError::RlpIncorrectListLen); + } + let protocol = rlp.val_at::>(1)?; + let request = rlp.val_at::>(2)?; + Self { + id, + body: RequestBody::TalkReq { protocol, request }, + } + } + _ => unreachable!("Implementation does not adhere to wire protocol"), + }; + Ok(message) + } +} + +#[derive(Debug, Clone, PartialEq, Eq)] +pub enum RequestBody { + /// A PING request. + Ping { + /// Our current ENR sequence number. + enr_seq: u64, + }, + /// A FINDNODE request. + FindNode { + /// The distance(s) of peers we expect to be returned in the response. + distances: Vec, + }, + /// A Talk request. + TalkReq { + /// The protocol requesting. + protocol: Vec, + /// The request. + request: Vec, + }, +} + +impl std::fmt::Display for RequestBody { + fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result { + match self { + RequestBody::Ping { enr_seq } => write!(f, "PING: enr_seq: {enr_seq}"), + RequestBody::FindNode { distances } => { + write!(f, "FINDNODE Request: distance: {distances:?}") + } + RequestBody::TalkReq { protocol, request } => write!( + f, + "TALK: protocol: {}, request: {}", + hex::encode(protocol), + hex::encode(request) + ), + } + } +} + +/// Type to manage the request IDs. +#[derive(Debug, Clone, PartialEq, Hash, Eq)] +pub struct RequestId(pub Vec); + +impl From> for RequestId { + fn from(v: Vec) -> Self { + RequestId(v) + } +} + +impl RequestId { + /// Decodes the ID from a raw bytes. + pub fn decode(data: Vec) -> Result { + if data.len() > 8 { + return Err(DecoderError::Custom("Invalid ID length")); + } + Ok(RequestId(data)) + } + + pub fn random() -> Self { + let rand: u64 = rand::random(); + RequestId(rand.to_be_bytes().to_vec()) + } + + pub fn as_bytes(&self) -> &[u8] { + &self.0 + } +} + +impl std::fmt::Display for RequestId { + fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result { + write!(f, "{}", hex::encode(&self.0)) + } +} diff --git a/src/rpc/response.rs b/src/rpc/response.rs new file mode 100644 index 000000000..e0978e6fa --- /dev/null +++ b/src/rpc/response.rs @@ -0,0 +1,241 @@ +use super::{MessageType, Payload, RequestBody, RequestId}; +use crate::Enr; +use derive_more::Display; +use rlp::{DecoderError, Rlp, RlpStream}; +use std::{ + convert::TryInto, + net::{IpAddr, Ipv6Addr}, +}; +use tracing::debug; + +/// A response sent in response to a [`super::Request`] +#[derive(Debug, Clone, PartialEq, Eq, Display)] +#[display(fmt = "Response: id: {id}: {body}")] +pub struct Response { + /// The [`RequestId`] of the request that triggered this response. + pub id: RequestId, + /// The body of this response. + pub body: ResponseBody, +} + +impl Payload for Response { + /// Matches a response type to its message type id. + fn msg_type(&self) -> u8 { + match &self.body { + ResponseBody::Pong { .. } => MessageType::Pong as u8, + ResponseBody::Nodes { .. } => MessageType::Nodes as u8, + ResponseBody::TalkResp { .. } => MessageType::TalkResp as u8, + } + } + + /// Encodes a response message to RLP-encoded bytes. + fn encode(self) -> Vec { + let mut buf = Vec::with_capacity(10); + let msg_type = self.msg_type(); + buf.push(msg_type); + let id = &self.id; + match self.body { + ResponseBody::Pong { enr_seq, ip, port } => { + let mut s = RlpStream::new(); + s.begin_list(4); + s.append(&id.as_bytes()); + s.append(&enr_seq); + match ip { + IpAddr::V4(addr) => s.append(&(&addr.octets() as &[u8])), + IpAddr::V6(addr) => s.append(&(&addr.octets() as &[u8])), + }; + s.append(&port); + buf.extend_from_slice(&s.out()); + buf + } + ResponseBody::Nodes { total, nodes } => { + let mut s = RlpStream::new(); + s.begin_list(3); + s.append(&id.as_bytes()); + s.append(&total); + + if nodes.is_empty() { + s.begin_list(0); + } else { + s.begin_list(nodes.len()); + for node in nodes { + s.append(&node); + } + } + buf.extend_from_slice(&s.out()); + buf + } + ResponseBody::TalkResp { response } => { + let mut s = RlpStream::new(); + s.begin_list(2); + s.append(&id.as_bytes()); + s.append(&response); + buf.extend_from_slice(&s.out()); + buf + } + } + } + + /// Decodes RLP-encoded bytes into a response message. + fn decode(msg_type: u8, rlp: &Rlp<'_>) -> Result { + let list_len = rlp.item_count()?; + let id = RequestId::decode(rlp.val_at::>(0)?)?; + let response = match msg_type.try_into()? { + MessageType::Pong => { + // Pong Response + if list_len != 4 { + debug!( + "Ping Response has an invalid RLP list length. Expected 4, found {}", + list_len + ); + return Err(DecoderError::RlpIncorrectListLen); + } + let ip_bytes = rlp.val_at::>(2)?; + let ip = match ip_bytes.len() { + 4 => { + let mut ip = [0u8; 4]; + ip.copy_from_slice(&ip_bytes); + IpAddr::from(ip) + } + 16 => { + let mut ip = [0u8; 16]; + ip.copy_from_slice(&ip_bytes); + let ipv6 = Ipv6Addr::from(ip); + if ipv6.is_loopback() { + // Checking if loopback address since IPv6Addr::to_ipv4 returns + // IPv4 address for IPv6 loopback address. + IpAddr::V6(ipv6) + } else if let Some(ipv4) = ipv6.to_ipv4() { + // If the ipv6 is ipv4 compatible/mapped, simply return the ipv4. + IpAddr::V4(ipv4) + } else { + IpAddr::V6(ipv6) + } + } + _ => { + debug!("Ping Response has incorrect byte length for IP"); + return Err(DecoderError::RlpIncorrectListLen); + } + }; + let port = rlp.val_at::(3)?; + Self { + id, + body: ResponseBody::Pong { + enr_seq: rlp.val_at::(1)?, + ip, + port, + }, + } + } + MessageType::Nodes => { + // Nodes Response + if list_len != 3 { + debug!( + "Nodes Response has an invalid RLP list length. Expected 3, found {}", + list_len + ); + return Err(DecoderError::RlpIncorrectListLen); + } + + let nodes = { + let enr_list_rlp = rlp.at(2)?; + if enr_list_rlp.is_empty() { + // no records + vec![] + } else { + enr_list_rlp.as_list::()? + } + }; + Self { + id, + body: ResponseBody::Nodes { + total: rlp.val_at::(1)?, + nodes, + }, + } + } + MessageType::TalkResp => { + // Talk Response + if list_len != 2 { + debug!( + "Talk Response has an invalid RLP list length. Expected 2, found {}", + list_len + ); + return Err(DecoderError::RlpIncorrectListLen); + } + let response = rlp.val_at::>(1)?; + Self { + id, + body: ResponseBody::TalkResp { response }, + } + } + _ => unreachable!("Implementation does not adhere to wire protocol"), + }; + Ok(response) + } +} + +impl Response { + /// Determines if the response is a valid response to the given request. + pub fn match_request(&self, req: &RequestBody) -> bool { + match self.body { + ResponseBody::Pong { .. } => matches!(req, RequestBody::Ping { .. }), + ResponseBody::Nodes { .. } => { + matches!(req, RequestBody::FindNode { .. }) + } + ResponseBody::TalkResp { .. } => matches!(req, RequestBody::TalkReq { .. }), + } + } +} + +#[derive(Debug, Clone, PartialEq, Eq)] +pub enum ResponseBody { + /// A PONG response. + Pong { + /// The current ENR sequence number of the responder. + enr_seq: u64, + /// Our external IP address as observed by the responder. + ip: IpAddr, + /// Our external UDP port as observed by the responder. + port: u16, + }, + /// A NODES response. + Nodes { + /// The total number of responses that make up this response. + total: u64, + /// A list of ENR's returned by the responder. + nodes: Vec, + }, + /// The TALK response. + TalkResp { + /// The response for the talk. + response: Vec, + }, +} + +impl std::fmt::Display for ResponseBody { + fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result { + match self { + ResponseBody::Pong { enr_seq, ip, port } => { + write!(f, "PONG: Enr-seq: {enr_seq}, Ip: {ip:?}, Port: {port}") + } + ResponseBody::Nodes { total, nodes } => { + write!(f, "NODES: total: {total}, Nodes: [")?; + let mut first = true; + for id in nodes { + if !first { + write!(f, ", {id}")?; + } else { + write!(f, "{id}")?; + } + first = false; + } + + write!(f, "]") + } + ResponseBody::TalkResp { response } => { + write!(f, "Response: Response {}", hex::encode(response)) + } + } + } +} diff --git a/src/service.rs b/src/service.rs index 68e1ffb62..f55089010 100644 --- a/src/service.rs +++ b/src/service.rs @@ -19,7 +19,7 @@ use self::{ }; use crate::{ error::{RequestError, ResponseError}, - handler::{Handler, HandlerIn, HandlerOut}, + handler::{EnrRequestData, Handler, HandlerIn, HandlerOut}, kbucket::{ self, ConnectionDirection, ConnectionState, FailureReason, InsertResult, KBucketsTable, NodeStatus, UpdateResult, MAX_NODES_PER_BUCKET, @@ -85,7 +85,7 @@ impl Drop for TalkRequest { let response = Response { id: self.id.clone(), - body: ResponseBody::Talk { response: vec![] }, + body: ResponseBody::TalkResp { response: vec![] }, }; debug!("Sending empty TALK response to {}", self.node_address); @@ -120,7 +120,7 @@ impl TalkRequest { let response = Response { id: self.id.clone(), - body: ResponseBody::Talk { response }, + body: ResponseBody::TalkResp { response }, }; self.sender @@ -387,20 +387,60 @@ impl Service { HandlerOut::Response(node_address, response) => { self.handle_rpc_response(node_address, *response); } - HandlerOut::WhoAreYou(whoareyou_ref) => { + HandlerOut::RequestEnr(EnrRequestData::WhoAreYou(whoareyou_ref)) => { // check what our latest known ENR is for this node. if let Some(known_enr) = self.find_enr(&whoareyou_ref.0.node_id) { - if let Err(e) = self.handler_send.send(HandlerIn::WhoAreYou(whoareyou_ref, Some(known_enr))) { + if let Err(e) = self.handler_send.send(HandlerIn::EnrResponse(Some(known_enr), EnrRequestData::WhoAreYou(whoareyou_ref))) { warn!("Failed to send whoareyou {}", e); }; } else { // do not know of this peer debug!("NodeId unknown, requesting ENR. {}", whoareyou_ref.0); - if let Err(e) = self.handler_send.send(HandlerIn::WhoAreYou(whoareyou_ref, None)) { + if let Err(e) = self.handler_send.send(HandlerIn::EnrResponse(None, EnrRequestData::WhoAreYou(whoareyou_ref))) { warn!("Failed to send who are you to unknown enr peer {}", e); } } } + HandlerOut::RequestEnr(EnrRequestData::Nat(relay_initiation)) => { + // Update initiator's Enr if it's in kbuckets + let initiator_enr = relay_initiation.initiator_enr(); + let initiator_key = kbucket::Key::from(initiator_enr.node_id()); + match self.kbuckets.write().entry(&initiator_key) { + kbucket::Entry::Present(ref mut entry, _) => { + let enr = entry.value_mut(); + if enr.seq() < initiator_enr.seq() { + *enr = initiator_enr.clone(); + } + } + kbucket::Entry::Pending(ref mut entry, _) => { + let enr = entry.value_mut(); + if enr.seq() < initiator_enr.seq() { + *enr = initiator_enr.clone(); + } + } + _ => () + } + // check if we know the target node id in our routing table, otherwise + // drop relay attempt. + let target_node_id = relay_initiation.target_node_id(); + let target_key = kbucket::Key::from(target_node_id); + if let kbucket::Entry::Present(entry, _) = self.kbuckets.write().entry(&target_key) { + let target_enr = entry.value().clone(); + if let Err(e) = self.handler_send.send(HandlerIn::EnrResponse(Some(target_enr), EnrRequestData::Nat(relay_initiation))) { + warn!( + "Failed to send target enr to relay process, error: {e}" + ); + } + } else { + let initiator_node_id = relay_initiation.initiator_enr().node_id(); + warn!( + initiator_node_id=%initiator_node_id, + target_node_id=%target_node_id, + "Peer requested relaying to a peer not in k-buckets" + ); + } + }, + HandlerOut::PingAllPeers => self.ping_connected_peers(), HandlerOut::RequestFailed(request_id, error) => { if let RequestError::Timeout = error { debug!("RPC Request timed out. id: {}", request_id); @@ -581,8 +621,8 @@ impl Service { } } kbucket::Entry::Pending(ref mut entry, _) => { - if entry.value().seq() < enr_seq { - let enr = entry.value().clone(); + if entry.value_mut().seq() < enr_seq { + let enr = entry.value_mut().clone(); to_request_enr = Some(enr); } } @@ -622,7 +662,7 @@ impl Service { warn!("Failed to send response {}", e) } } - RequestBody::Talk { protocol, request } => { + RequestBody::TalkReq { protocol, request } => { let req = TalkRequest { id, node_address, @@ -834,11 +874,20 @@ impl Service { updated = true; info!( "Local UDP ip6 socket updated to: {}", - new_ip6, + new_ip6 ); self.send_event(Discv5Event::SocketUpdated( new_ip6, )); + // Notify Handler of socket update + if let Err(e) = + self.handler_send.send(HandlerIn::SocketUpdate( + local_ip6_socket.map(SocketAddr::V6), + new_ip6, + )) + { + warn!("Failed to send socket update to handler: {}", e); + }; } Err(e) => { warn!("Failed to update local UDP ip6 socket. ip6: {}, error: {:?}", new_ip6, e); @@ -858,6 +907,15 @@ impl Service { self.send_event(Discv5Event::SocketUpdated( new_ip4, )); + // Notify Handler of socket update + if let Err(e) = + self.handler_send.send(HandlerIn::SocketUpdate( + local_ip4_socket.map(SocketAddr::V4), + new_ip4, + )) + { + warn!("Failed to send socket update {}", e); + }; } Err(e) => { warn!("Failed to update local UDP socket. ip: {}, error: {:?}", new_ip4, e); @@ -889,7 +947,7 @@ impl Service { } } } - ResponseBody::Talk { response } => { + ResponseBody::TalkResp { response } => { // Send the response to the user match active_request.callback { Some(CallbackResponse::Talk(callback)) => { @@ -981,7 +1039,7 @@ impl Service { request: Vec, callback: oneshot::Sender, RequestError>>, ) { - let request_body = RequestBody::Talk { protocol, request }; + let request_body = RequestBody::TalkReq { protocol, request }; let active_request = ActiveRequest { contact, @@ -1056,13 +1114,16 @@ impl Service { for enr in nodes_to_send.into_iter() { let entry_size = rlp::encode(&enr).len(); // Responses assume that a session is established. Thus, on top of the encoded - // ENR's the packet should be a regular message. A regular message has an IV (16 - // bytes), and a header of 55 bytes. The find-nodes RPC requires 16 bytes for the ID and the - // `total` field. Also there is a 16 byte HMAC for encryption and an extra byte for - // RLP encoding. + // ENR's the packet should be a session message, which is the same data + // structure as a regular message. + // A session message has an IV (16 bytes), and a header of 55 bytes. The + // find-nodes RPC requires 16 bytes for the ID and the `total` field. Also there + // is a 16 byte HMAC for encryption and an extra byte for RLP encoding. + // + // We could also be responding via an authheader (this message could be in + // contained in a handshake message) which can take up to 282 bytes in + // the header, leaving even less space for the NODES response. // - // We could also be responding via an authheader which can take up to 282 bytes in its - // header. // As most messages will be normal messages we will try and pack as many ENR's we // can in and drop the response packet if a user requests an auth message of a very // packed response. @@ -1202,7 +1263,7 @@ impl Service { let must_update_enr = match self.kbuckets.write().entry(&key) { kbucket::Entry::Present(entry, _) => entry.value().seq() < enr.seq(), - kbucket::Entry::Pending(mut entry, _) => entry.value().seq() < enr.seq(), + kbucket::Entry::Pending(mut entry, _) => entry.value_mut().seq() < enr.seq(), _ => false, }; diff --git a/src/service/test.rs b/src/service/test.rs index b2860700e..d02f8394c 100644 --- a/src/service/test.rs +++ b/src/service/test.rs @@ -173,7 +173,7 @@ async fn test_connection_direction_on_inject_session_established() { let ip = std::net::Ipv4Addr::LOCALHOST; let enr = EnrBuilder::new("v4") .ip4(ip) - .udp4(10001) + .udp4(10003) .build(&enr_key1) .unwrap(); @@ -181,7 +181,7 @@ async fn test_connection_direction_on_inject_session_established() { let ip2 = std::net::Ipv4Addr::LOCALHOST; let enr2 = EnrBuilder::new("v4") .ip4(ip2) - .udp4(10002) + .udp4(10004) .build(&enr_key2) .unwrap(); diff --git a/src/socket/mod.rs b/src/socket/mod.rs index 209348efc..66c07d857 100644 --- a/src/socket/mod.rs +++ b/src/socket/mod.rs @@ -1,4 +1,4 @@ -use crate::{packet::ProtocolIdentity, Executor}; +use crate::{packet::ProtocolIdentity, Executor, IpMode}; use parking_lot::RwLock; use recv::*; use send::*; @@ -24,7 +24,7 @@ pub use filter::{ FilterConfig, }; pub use recv::InboundPacket; -pub use send::OutboundPacket; +pub use send::{Outbound, OutboundPacket}; /// Configuration for the sockets to listen on. /// @@ -47,6 +47,42 @@ pub enum ListenConfig { }, } +impl ListenConfig { + pub fn ipv4(&self) -> Option { + match self { + ListenConfig::Ipv4 { ip, .. } | ListenConfig::DualStack { ipv4: ip, .. } => Some(*ip), + _ => None, + } + } + + pub fn ipv6(&self) -> Option { + match self { + ListenConfig::Ipv6 { ip, .. } | ListenConfig::DualStack { ipv6: ip, .. } => Some(*ip), + _ => None, + } + } + + pub fn ipv4_port(&self) -> Option { + match self { + ListenConfig::Ipv4 { port, .. } + | ListenConfig::DualStack { + ipv4_port: port, .. + } => Some(*port), + _ => None, + } + } + + pub fn ipv6_port(&self) -> Option { + match self { + ListenConfig::Ipv6 { port, .. } + | ListenConfig::DualStack { + ipv6_port: port, .. + } => Some(*port), + _ => None, + } + } +} + /// Convenience objects for setting up the recv handler. pub struct SocketConfig { /// The executor to spawn the tasks. @@ -65,7 +101,7 @@ pub struct SocketConfig { /// Creates the UDP socket and handles the exit futures for the send/recv UDP handlers. pub struct Socket { - pub send: mpsc::Sender, + pub send: mpsc::Sender, pub recv: mpsc::Receiver, sender_exit: Option>, recv_exit: Option>, @@ -243,6 +279,14 @@ impl ListenConfig { }, } } + + pub fn ip_mode(&self) -> IpMode { + match self { + ListenConfig::Ipv4 { .. } => IpMode::Ip4, + ListenConfig::Ipv6 { .. } => IpMode::Ip6, + ListenConfig::DualStack { .. } => IpMode::DualStack, + } + } } impl Default for ListenConfig { diff --git a/src/socket/recv.rs b/src/socket/recv.rs index a300a09e0..daeb5723e 100644 --- a/src/socket/recv.rs +++ b/src/socket/recv.rs @@ -185,7 +185,20 @@ impl RecvHandler { match Packet::decode::

(&self.node_id, &recv_buffer[..length]) { Ok(p) => p, Err(e) => { - debug!("Packet decoding failed: {:?}", e); // could not decode the packet, drop it + // This could be a packet to keep a NAT hole punched for this node in the + // sender's NAT, hence only serves purpose for the sender. + if length == 0 { + debug!( + "Empty packet, possibly to keep a hole punched, dropping. src: {}", + src_address + ); + } else { + // Could not decode the packet, drop it. + debug!( + "Packet decoding failed, src: {}, error: {:?}", + src_address, e + ); + } return; } }; diff --git a/src/socket/send.rs b/src/socket/send.rs index 31f52180d..4f051ab43 100644 --- a/src/socket/send.rs +++ b/src/socket/send.rs @@ -1,5 +1,6 @@ //! This is a standalone task that encodes and sends Discv5 UDP packets use crate::{metrics::METRICS, node_info::NodeAddress, packet::*, Executor}; +use derive_more::From; use std::{net::SocketAddr, sync::Arc}; use tokio::{ net::UdpSocket, @@ -7,6 +8,21 @@ use tokio::{ }; use tracing::{debug, error, trace, warn}; +#[derive(From)] +pub enum Outbound { + Packet(OutboundPacket), + KeepHolePunched(SocketAddr), +} + +impl Outbound { + pub fn dst(&self) -> &SocketAddr { + match self { + Self::Packet(packet) => &packet.node_address.socket_addr, + Self::KeepHolePunched(dst) => dst, + } + } +} + pub struct OutboundPacket { /// The destination node address pub node_address: NodeAddress, @@ -21,7 +37,7 @@ pub(crate) struct SendHandler { /// The UDP send socket for IPv6. send_ipv6: Option>, /// The channel to respond to send requests. - handler_recv: mpsc::Receiver, + handler_recv: mpsc::Receiver, /// Exit channel to shutdown the handler. exit: oneshot::Receiver<()>, } @@ -39,7 +55,7 @@ impl SendHandler { executor: Box, send_ipv4: Option>, send_ipv6: Option>, - ) -> (mpsc::Sender, oneshot::Sender<()>) { + ) -> (mpsc::Sender, oneshot::Sender<()>) { let (exit_send, exit) = oneshot::channel(); let (handler_send, handler_recv) = mpsc::channel(30); @@ -62,13 +78,20 @@ impl SendHandler { async fn start(&mut self) { loop { tokio::select! { - Some(packet) = self.handler_recv.recv() => { - let encoded_packet = packet.packet.encode::

(&packet.node_address.node_id); - if encoded_packet.len() > MAX_PACKET_SIZE { - warn!("Sending packet larger than max size: {} max: {}", encoded_packet.len(), MAX_PACKET_SIZE); - } - let addr = &packet.node_address.socket_addr; - if let Err(e) = self.send(&encoded_packet, addr).await { + Some(outbound) = self.handler_recv.recv() => { + let (addr, encoded_packet) = match outbound { + Outbound::Packet(outbound_packet) => { + let dst_id = outbound_packet.node_address.node_id; + let encoded_packet = outbound_packet.packet.encode::

(&dst_id); + if encoded_packet.len() > MAX_PACKET_SIZE { + warn!("Sending packet larger than max size: {} max: {}", encoded_packet.len(), MAX_PACKET_SIZE); + } + let dst_addr = outbound_packet.node_address.socket_addr; + (dst_addr, encoded_packet) + } + Outbound::KeepHolePunched(dst) => (dst, vec![]), + }; + if let Err(e) = self.send(&encoded_packet, &addr).await { match e { Error::Io(e) => { trace!("Could not send packet to {addr} . Error: {e}");