Skip to content

Commit

Permalink
Fix Snocket bug in the accept/connect
Browse files Browse the repository at this point in the history
Basically, if I connect to someone and someone connects to me, before the connect returns (and before the remote accept returns as well) the local accept can return first masking itself as the remote one because we have no way to distinguish directions.
  • Loading branch information
bolt12 committed Oct 29, 2021
1 parent ca1b03f commit 6ecf01c
Showing 1 changed file with 20 additions and 9 deletions.
29 changes: 20 additions & 9 deletions ouroboros-network-framework/src/Simulation/Network/Snocket.hs
Original file line number Diff line number Diff line change
Expand Up @@ -95,7 +95,7 @@ stepScriptSTM scriptVar = do
return x


data Connection m = Connection
data Connection m addr = Connection
{ -- | Attenuated channels of a connection.
--
connChannelLocal :: !(AttenuatedChannel m)
Expand All @@ -109,6 +109,10 @@ data Connection m = Connection
-- open.
--
, connState :: !ConnectionState

-- | Provider of this Connection, so one can know its origin and decide
-- accordingly when accepting/connecting a connection.
, connProvider :: !addr
}


Expand All @@ -132,7 +136,7 @@ data ConnectionState
deriving (Eq, Show)


dualConnection :: Connection m -> Connection m
dualConnection :: Connection m addr -> Connection m addr
dualConnection conn@Connection { connChannelLocal, connChannelRemote } =
conn { connChannelLocal = connChannelRemote
, connChannelRemote = connChannelLocal
Expand All @@ -149,7 +153,7 @@ mkConnection :: ( MonadLabelledSTM m
(SnocketTrace m (TestAddress addr)))
-> BearerInfo
-> ConnectionId (TestAddress addr)
-> STM m (Connection m)
-> STM m (Connection m (TestAddress addr))
mkConnection tr bearerInfo connId@ConnectionId { localAddress, remoteAddress } = do
(channelLocal, channelRemote) <-
newConnectedAttenuatedChannelPair
Expand All @@ -176,6 +180,7 @@ mkConnection tr bearerInfo connId@ConnectionId { localAddress, remoteAddress } =
channelRemote
(biSDUSize bearerInfo)
SYN_SENT
localAddress


-- | Connection id independent of who provisioned the connection. 'NormalisedId'
Expand Down Expand Up @@ -209,7 +214,7 @@ data NetworkState m addr = NetworkState {

-- | Registry of active connections.
--
nsConnections :: StrictTVar m (Map (NormalisedId addr) (Connection m)),
nsConnections :: StrictTVar m (Map (NormalisedId addr) (Connection m addr)),

-- | Get an unused ephemeral address.
--
Expand Down Expand Up @@ -448,7 +453,7 @@ data FD_ m addr
-- assigned to it. This corresponds to 'SYN_SENT' state.
--
| FDConnecting !(ConnectionId addr)
!(Connection m)
!(Connection m addr)

-- | 'FD_' for snockets in connected state.
--
Expand All @@ -458,7 +463,7 @@ data FD_ m addr
| FDConnected
!(ConnectionId addr)
-- ^ local and remote addresses
!(Connection m)
!(Connection m addr)
-- ^ connection

-- | 'FD_' of a closed file descriptor; we keep 'ConnectionId' just for
Expand Down Expand Up @@ -982,9 +987,14 @@ mkSnocket state tr = Snocket { getLocalAddr
let connId = ConnectionId localAddress (cwiAddress cwi)

case Map.lookup (normaliseId connId) connMap of
Nothing -> return False
Just (Connection _ _ _ SYN_SENT) -> return True
_ -> return False
Nothing ->
return False
Just (Connection _ _ _ SYN_SENT provider) ->
return ( provider /= localAddress
|| localAddress == cwiAddress cwi
)
_ ->
return False

accept_ = Accept $ \unmask -> do
bracketOnError
Expand Down Expand Up @@ -1062,6 +1072,7 @@ mkSnocket state tr = Snocket { getLocalAddr
, connChannelRemote = channelRemote
, connSDUSize = sduSize
, connState = ESTABLISHED
, connProvider = remoteAddress
})

traceWith tr (WithAddr (Just (localAddress connId)) Nothing
Expand Down

0 comments on commit 6ecf01c

Please sign in to comment.