diff --git a/ouroboros-network-framework/src/Simulation/Network/Snocket.hs b/ouroboros-network-framework/src/Simulation/Network/Snocket.hs index 1f965ef765e..618055726a2 100644 --- a/ouroboros-network-framework/src/Simulation/Network/Snocket.hs +++ b/ouroboros-network-framework/src/Simulation/Network/Snocket.hs @@ -95,7 +95,7 @@ stepScriptSTM scriptVar = do return x -data Connection m = Connection +data Connection m addr = Connection { -- | Attenuated channels of a connection. -- connChannelLocal :: !(AttenuatedChannel m) @@ -109,6 +109,10 @@ data Connection m = Connection -- open. -- , connState :: !ConnectionState + + -- | Provider of this Connection, so one can know its origin and decide + -- accordingly when accepting/connecting a connection. + , connProvider :: !addr } @@ -132,7 +136,7 @@ data ConnectionState deriving (Eq, Show) -dualConnection :: Connection m -> Connection m +dualConnection :: Connection m addr -> Connection m addr dualConnection conn@Connection { connChannelLocal, connChannelRemote } = conn { connChannelLocal = connChannelRemote , connChannelRemote = connChannelLocal @@ -149,7 +153,7 @@ mkConnection :: ( MonadLabelledSTM m (SnocketTrace m (TestAddress addr))) -> BearerInfo -> ConnectionId (TestAddress addr) - -> STM m (Connection m) + -> STM m (Connection m (TestAddress addr)) mkConnection tr bearerInfo connId@ConnectionId { localAddress, remoteAddress } = do (channelLocal, channelRemote) <- newConnectedAttenuatedChannelPair @@ -176,6 +180,7 @@ mkConnection tr bearerInfo connId@ConnectionId { localAddress, remoteAddress } = channelRemote (biSDUSize bearerInfo) SYN_SENT + localAddress -- | Connection id independent of who provisioned the connection. 'NormalisedId' @@ -209,7 +214,7 @@ data NetworkState m addr = NetworkState { -- | Registry of active connections. -- - nsConnections :: StrictTVar m (Map (NormalisedId addr) (Connection m)), + nsConnections :: StrictTVar m (Map (NormalisedId addr) (Connection m addr)), -- | Get an unused ephemeral address. -- @@ -448,7 +453,7 @@ data FD_ m addr -- assigned to it. This corresponds to 'SYN_SENT' state. -- | FDConnecting !(ConnectionId addr) - !(Connection m) + !(Connection m addr) -- | 'FD_' for snockets in connected state. -- @@ -458,7 +463,7 @@ data FD_ m addr | FDConnected !(ConnectionId addr) -- ^ local and remote addresses - !(Connection m) + !(Connection m addr) -- ^ connection -- | 'FD_' of a closed file descriptor; we keep 'ConnectionId' just for @@ -982,9 +987,14 @@ mkSnocket state tr = Snocket { getLocalAddr let connId = ConnectionId localAddress (cwiAddress cwi) case Map.lookup (normaliseId connId) connMap of - Nothing -> return False - Just (Connection _ _ _ SYN_SENT) -> return True - _ -> return False + Nothing -> + return False + Just (Connection _ _ _ SYN_SENT provider) -> + return ( provider /= localAddress + || localAddress == cwiAddress cwi + ) + _ -> + return False accept_ = Accept $ \unmask -> do bracketOnError @@ -1062,6 +1072,7 @@ mkSnocket state tr = Snocket { getLocalAddr , connChannelRemote = channelRemote , connSDUSize = sduSize , connState = ESTABLISHED + , connProvider = remoteAddress }) traceWith tr (WithAddr (Just (localAddress connId)) Nothing