From 6ecf01c3a300b5c10fc58a0411bb1f41c97b0457 Mon Sep 17 00:00:00 2001 From: Armando Santos Date: Mon, 25 Oct 2021 11:26:41 +0100 Subject: [PATCH] Fix Snocket bug in the accept/connect Basically, if I connect to someone and someone connects to me, before the connect returns (and before the remote accept returns as well) the local accept can return first masking itself as the remote one because we have no way to distinguish directions. --- .../src/Simulation/Network/Snocket.hs | 29 +++++++++++++------ 1 file changed, 20 insertions(+), 9 deletions(-) diff --git a/ouroboros-network-framework/src/Simulation/Network/Snocket.hs b/ouroboros-network-framework/src/Simulation/Network/Snocket.hs index 1f965ef765e..618055726a2 100644 --- a/ouroboros-network-framework/src/Simulation/Network/Snocket.hs +++ b/ouroboros-network-framework/src/Simulation/Network/Snocket.hs @@ -95,7 +95,7 @@ stepScriptSTM scriptVar = do return x -data Connection m = Connection +data Connection m addr = Connection { -- | Attenuated channels of a connection. -- connChannelLocal :: !(AttenuatedChannel m) @@ -109,6 +109,10 @@ data Connection m = Connection -- open. -- , connState :: !ConnectionState + + -- | Provider of this Connection, so one can know its origin and decide + -- accordingly when accepting/connecting a connection. + , connProvider :: !addr } @@ -132,7 +136,7 @@ data ConnectionState deriving (Eq, Show) -dualConnection :: Connection m -> Connection m +dualConnection :: Connection m addr -> Connection m addr dualConnection conn@Connection { connChannelLocal, connChannelRemote } = conn { connChannelLocal = connChannelRemote , connChannelRemote = connChannelLocal @@ -149,7 +153,7 @@ mkConnection :: ( MonadLabelledSTM m (SnocketTrace m (TestAddress addr))) -> BearerInfo -> ConnectionId (TestAddress addr) - -> STM m (Connection m) + -> STM m (Connection m (TestAddress addr)) mkConnection tr bearerInfo connId@ConnectionId { localAddress, remoteAddress } = do (channelLocal, channelRemote) <- newConnectedAttenuatedChannelPair @@ -176,6 +180,7 @@ mkConnection tr bearerInfo connId@ConnectionId { localAddress, remoteAddress } = channelRemote (biSDUSize bearerInfo) SYN_SENT + localAddress -- | Connection id independent of who provisioned the connection. 'NormalisedId' @@ -209,7 +214,7 @@ data NetworkState m addr = NetworkState { -- | Registry of active connections. -- - nsConnections :: StrictTVar m (Map (NormalisedId addr) (Connection m)), + nsConnections :: StrictTVar m (Map (NormalisedId addr) (Connection m addr)), -- | Get an unused ephemeral address. -- @@ -448,7 +453,7 @@ data FD_ m addr -- assigned to it. This corresponds to 'SYN_SENT' state. -- | FDConnecting !(ConnectionId addr) - !(Connection m) + !(Connection m addr) -- | 'FD_' for snockets in connected state. -- @@ -458,7 +463,7 @@ data FD_ m addr | FDConnected !(ConnectionId addr) -- ^ local and remote addresses - !(Connection m) + !(Connection m addr) -- ^ connection -- | 'FD_' of a closed file descriptor; we keep 'ConnectionId' just for @@ -982,9 +987,14 @@ mkSnocket state tr = Snocket { getLocalAddr let connId = ConnectionId localAddress (cwiAddress cwi) case Map.lookup (normaliseId connId) connMap of - Nothing -> return False - Just (Connection _ _ _ SYN_SENT) -> return True - _ -> return False + Nothing -> + return False + Just (Connection _ _ _ SYN_SENT provider) -> + return ( provider /= localAddress + || localAddress == cwiAddress cwi + ) + _ -> + return False accept_ = Accept $ \unmask -> do bracketOnError @@ -1062,6 +1072,7 @@ mkSnocket state tr = Snocket { getLocalAddr , connChannelRemote = channelRemote , connSDUSize = sduSize , connState = ESTABLISHED + , connProvider = remoteAddress }) traceWith tr (WithAddr (Just (localAddress connId)) Nothing