Skip to content

Commit

Permalink
Handle context cancellation properly (#428)
Browse files Browse the repository at this point in the history
* fix(cancellation): handle message cancellation properly

a previous fix #391, attempted to address a lock that occurred on context cancel. however in doing
so, it introduced a new lock. essentially, if a message was not sent to the
requestmanager/responsemanager go routine, waiting for a response to that message could last
indefinitely. the previous fix therefore stopped waiting when the calling context cancelled. However
once a message reaches the go routine of the requestmanager/responsemanager, it's important that
it's processed to completion, so that the the message loop doesn't lock. The proper fix is to detect
when the message is sent successfully, and if so, wait for it to be processed. If it isn't sent, we
can safely abort the go routine immediately.

* fix(race): resolve race condition with test responses
  • Loading branch information
hannahhoward authored Aug 2, 2023
1 parent 6a2fa5c commit cb19a45
Show file tree
Hide file tree
Showing 3 changed files with 70 additions and 42 deletions.
54 changes: 34 additions & 20 deletions requestmanager/client.go
Original file line number Diff line number Diff line change
Expand Up @@ -187,13 +187,14 @@ func (rm *RequestManager) NewRequest(ctx context.Context,

inProgressRequestChan := make(chan inProgressRequest)

rm.send(&newRequestMessage{requestID, span, p, root, selectorNode, extensions, inProgressRequestChan}, ctx.Done())
err := rm.send(&newRequestMessage{requestID, span, p, root, selectorNode, extensions, inProgressRequestChan}, ctx.Done())
if err != nil {
return rm.emptyResponse()
}
var receivedInProgressRequest inProgressRequest
select {
case <-rm.ctx.Done():
return rm.emptyResponse()
case <-ctx.Done():
return rm.emptyResponse()
case receivedInProgressRequest = <-inProgressRequestChan:
}

Expand Down Expand Up @@ -283,12 +284,13 @@ func (rm *RequestManager) cancelRequestAndClose(requestID graphsync.RequestID,
// CancelRequest cancels the given request ID and waits for the request to terminate
func (rm *RequestManager) CancelRequest(ctx context.Context, requestID graphsync.RequestID) error {
terminated := make(chan error, 1)
rm.send(&cancelRequestMessage{requestID, terminated, graphsync.RequestClientCancelledErr{}}, ctx.Done())
err := rm.send(&cancelRequestMessage{requestID, terminated, graphsync.RequestClientCancelledErr{}}, ctx.Done())
if err != nil {
return err
}
select {
case <-rm.ctx.Done():
return errors.New("context cancelled")
case <-ctx.Done():
return ctx.Err()
case err := <-terminated:
return err
}
Expand All @@ -300,19 +302,20 @@ func (rm *RequestManager) ProcessResponses(p peer.ID,
responses []gsmsg.GraphSyncResponse,
blks []blocks.Block) {

rm.send(&processResponsesMessage{p, responses, blks}, nil)
_ = rm.send(&processResponsesMessage{p, responses, blks}, nil)
}

// UnpauseRequest unpauses a request that was paused in a block hook based request ID
// Can also send extensions with unpause
func (rm *RequestManager) UnpauseRequest(ctx context.Context, requestID graphsync.RequestID, extensions ...graphsync.ExtensionData) error {
response := make(chan error, 1)
rm.send(&unpauseRequestMessage{requestID, extensions, response}, ctx.Done())
err := rm.send(&unpauseRequestMessage{requestID, extensions, response}, ctx.Done())
if err != nil {
return err
}
select {
case <-rm.ctx.Done():
return errors.New("context cancelled")
case <-ctx.Done():
return ctx.Err()
case err := <-response:
return err
}
Expand All @@ -321,12 +324,13 @@ func (rm *RequestManager) UnpauseRequest(ctx context.Context, requestID graphsyn
// PauseRequest pauses an in progress request (may take 1 or more blocks to process)
func (rm *RequestManager) PauseRequest(ctx context.Context, requestID graphsync.RequestID) error {
response := make(chan error, 1)
rm.send(&pauseRequestMessage{requestID, response}, ctx.Done())
err := rm.send(&pauseRequestMessage{requestID, response}, ctx.Done())
if err != nil {
return err
}
select {
case <-rm.ctx.Done():
return errors.New("context cancelled")
case <-ctx.Done():
return ctx.Err()
case err := <-response:
return err
}
Expand All @@ -335,26 +339,27 @@ func (rm *RequestManager) PauseRequest(ctx context.Context, requestID graphsync.
// UpdateRequest updates an in progress request
func (rm *RequestManager) UpdateRequest(ctx context.Context, requestID graphsync.RequestID, extensions ...graphsync.ExtensionData) error {
response := make(chan error, 1)
rm.send(&updateRequestMessage{requestID, extensions, response}, ctx.Done())
err := rm.send(&updateRequestMessage{requestID, extensions, response}, ctx.Done())
if err != nil {
return err
}
select {
case <-rm.ctx.Done():
return errors.New("context cancelled")
case <-ctx.Done():
return ctx.Err()
case err := <-response:
return err
}
}

// GetRequestTask gets data for the given task in the request queue
func (rm *RequestManager) GetRequestTask(p peer.ID, task *peertask.Task, requestExecutionChan chan executor.RequestTask) {
rm.send(&getRequestTaskMessage{p, task, requestExecutionChan}, nil)
_ = rm.send(&getRequestTaskMessage{p, task, requestExecutionChan}, nil)
}

// ReleaseRequestTask releases a task request the requestQueue
func (rm *RequestManager) ReleaseRequestTask(p peer.ID, task *peertask.Task, err error) {
done := make(chan struct{}, 1)
rm.send(&releaseRequestTaskMessage{p, task, err, done}, nil)
_ = rm.send(&releaseRequestTaskMessage{p, task, err, done}, nil)
select {
case <-rm.ctx.Done():
case <-done:
Expand All @@ -364,7 +369,7 @@ func (rm *RequestManager) ReleaseRequestTask(p peer.ID, task *peertask.Task, err
// PeerState gets stats on all outgoing requests for a given peer
func (rm *RequestManager) PeerState(p peer.ID) peerstate.PeerState {
response := make(chan peerstate.PeerState)
rm.send(&peerStateMessage{p, response}, nil)
_ = rm.send(&peerStateMessage{p, response}, nil)
select {
case <-rm.ctx.Done():
return peerstate.PeerState{}
Expand Down Expand Up @@ -392,11 +397,20 @@ func (rm *RequestManager) Shutdown() {
rm.cancel()
}

func (rm *RequestManager) send(message requestManagerMessage, done <-chan struct{}) {
func (rm *RequestManager) send(message requestManagerMessage, done <-chan struct{}) error {
// prioritize cancelled context
select {
case <-done:
return errors.New("unable to send message before cancellation")
default:
}
select {
case <-rm.ctx.Done():
return rm.ctx.Err()
case <-done:
return errors.New("unable to send message before cancellation")
case rm.messages <- message:
return nil
}
}

Expand Down
3 changes: 2 additions & 1 deletion requestmanager/server.go
Original file line number Diff line number Diff line change
Expand Up @@ -47,6 +47,7 @@ func (rm *RequestManager) run() {
for {
select {
case message := <-rm.messages:

message.handle(rm)
case <-rm.ctx.Done():
return
Expand Down Expand Up @@ -304,13 +305,13 @@ func (rm *RequestManager) processResponses(p peer.ID,
for _, blk := range blks {
blkMap[blk.Cid()] = blk.RawData()
}
rm.updateLastResponses(filteredResponses)
for _, response := range filteredResponses {
reconciledLoader := rm.inProgressRequestStatuses[response.RequestID()].reconciledLoader
if reconciledLoader != nil {
reconciledLoader.IngestResponse(response.Metadata(), trace.LinkFromContext(ctx), blkMap)
}
}
rm.updateLastResponses(filteredResponses)
rm.processTerminations(filteredResponses)
log.Debugf("end processing responses for peer %s", p)
}
Expand Down
55 changes: 34 additions & 21 deletions responsemanager/client.go
Original file line number Diff line number Diff line change
Expand Up @@ -157,18 +157,19 @@ func New(ctx context.Context,

// ProcessRequests processes incoming requests for the given peer
func (rm *ResponseManager) ProcessRequests(ctx context.Context, p peer.ID, requests []gsmsg.GraphSyncRequest) {
rm.send(&processRequestsMessage{p, requests}, ctx.Done())
_ = rm.send(&processRequestsMessage{p, requests}, ctx.Done())
}

// UnpauseResponse unpauses a response that was previously paused
func (rm *ResponseManager) UnpauseResponse(ctx context.Context, requestID graphsync.RequestID, extensions ...graphsync.ExtensionData) error {
response := make(chan error, 1)
rm.send(&unpauseRequestMessage{requestID, response, extensions}, ctx.Done())
err := rm.send(&unpauseRequestMessage{requestID, response, extensions}, ctx.Done())
if err != nil {
return err
}
select {
case <-rm.ctx.Done():
return errors.New("context cancelled")
case <-ctx.Done():
return ctx.Err()
case err := <-response:
return err
}
Expand All @@ -177,12 +178,13 @@ func (rm *ResponseManager) UnpauseResponse(ctx context.Context, requestID graphs
// PauseResponse pauses an in progress response (may take 1 or more blocks to process)
func (rm *ResponseManager) PauseResponse(ctx context.Context, requestID graphsync.RequestID) error {
response := make(chan error, 1)
rm.send(&pauseRequestMessage{requestID, response}, ctx.Done())
err := rm.send(&pauseRequestMessage{requestID, response}, ctx.Done())
if err != nil {
return err
}
select {
case <-rm.ctx.Done():
return errors.New("context cancelled")
case <-ctx.Done():
return ctx.Err()
case err := <-response:
return err
}
Expand All @@ -191,12 +193,13 @@ func (rm *ResponseManager) PauseResponse(ctx context.Context, requestID graphsyn
// CancelResponse cancels an in progress response
func (rm *ResponseManager) CancelResponse(ctx context.Context, requestID graphsync.RequestID) error {
response := make(chan error, 1)
rm.send(&errorRequestMessage{requestID, queryexecutor.ErrCancelledByCommand, response}, ctx.Done())
err := rm.send(&errorRequestMessage{requestID, queryexecutor.ErrCancelledByCommand, response}, ctx.Done())
if err != nil {
return err
}
select {
case <-rm.ctx.Done():
return errors.New("context cancelled")
case <-ctx.Done():
return ctx.Err()
case err := <-response:
return err
}
Expand All @@ -205,12 +208,13 @@ func (rm *ResponseManager) CancelResponse(ctx context.Context, requestID graphsy
// UpdateRequest updates an in progress response
func (rm *ResponseManager) UpdateResponse(ctx context.Context, requestID graphsync.RequestID, extensions ...graphsync.ExtensionData) error {
response := make(chan error, 1)
rm.send(&updateRequestMessage{requestID, extensions, response}, ctx.Done())
err := rm.send(&updateRequestMessage{requestID, extensions, response}, ctx.Done())
if err != nil {
return err
}
select {
case <-rm.ctx.Done():
return errors.New("context cancelled")
case <-ctx.Done():
return ctx.Err()
case err := <-response:
return err
}
Expand All @@ -219,7 +223,7 @@ func (rm *ResponseManager) UpdateResponse(ctx context.Context, requestID graphsy
// Synchronize is a utility method that blocks until all current messages are processed
func (rm *ResponseManager) synchronize() {
sync := make(chan error)
rm.send(&synchronizeMessage{sync}, nil)
_ = rm.send(&synchronizeMessage{sync}, nil)
select {
case <-rm.ctx.Done():
case <-sync:
Expand All @@ -228,18 +232,18 @@ func (rm *ResponseManager) synchronize() {

// StartTask starts the given task from the peer task queue
func (rm *ResponseManager) StartTask(task *peertask.Task, p peer.ID, responseTaskChan chan<- queryexecutor.ResponseTask) {
rm.send(&startTaskRequest{task, p, responseTaskChan}, nil)
_ = rm.send(&startTaskRequest{task, p, responseTaskChan}, nil)
}

// GetUpdates is called to read pending updates for a task and clear them
func (rm *ResponseManager) GetUpdates(requestID graphsync.RequestID, updatesChan chan<- []gsmsg.GraphSyncRequest) {
rm.send(&responseUpdateRequest{requestID, updatesChan}, nil)
_ = rm.send(&responseUpdateRequest{requestID, updatesChan}, nil)
}

// FinishTask marks a task from the task queue as done
func (rm *ResponseManager) FinishTask(task *peertask.Task, p peer.ID, err error) {
done := make(chan struct{}, 1)
rm.send(&finishTaskRequest{task, p, err, done}, nil)
_ = rm.send(&finishTaskRequest{task, p, err, done}, nil)
select {
case <-rm.ctx.Done():
case <-done:
Expand All @@ -249,7 +253,7 @@ func (rm *ResponseManager) FinishTask(task *peertask.Task, p peer.ID, err error)
// CloseWithNetworkError closes a request due to a network error
func (rm *ResponseManager) CloseWithNetworkError(requestID graphsync.RequestID) {
done := make(chan error, 1)
rm.send(&errorRequestMessage{requestID, queryexecutor.ErrNetworkError, done}, nil)
_ = rm.send(&errorRequestMessage{requestID, queryexecutor.ErrNetworkError, done}, nil)
select {
case <-rm.ctx.Done():
case <-done:
Expand All @@ -259,7 +263,7 @@ func (rm *ResponseManager) CloseWithNetworkError(requestID graphsync.RequestID)
// TerminateRequest indicates a request has finished sending data and should no longer be tracked
func (rm *ResponseManager) TerminateRequest(requestID graphsync.RequestID) {
done := make(chan struct{}, 1)
rm.send(&terminateRequestMessage{requestID, done}, nil)
_ = rm.send(&terminateRequestMessage{requestID, done}, nil)
select {
case <-rm.ctx.Done():
case <-done:
Expand All @@ -269,7 +273,7 @@ func (rm *ResponseManager) TerminateRequest(requestID graphsync.RequestID) {
// PeerState gets current state of the outgoing responses for a given peer
func (rm *ResponseManager) PeerState(p peer.ID) peerstate.PeerState {
response := make(chan peerstate.PeerState)
rm.send(&peerStateMessage{p, response}, nil)
_ = rm.send(&peerStateMessage{p, response}, nil)
select {
case <-rm.ctx.Done():
return peerstate.PeerState{}
Expand All @@ -278,11 +282,20 @@ func (rm *ResponseManager) PeerState(p peer.ID) peerstate.PeerState {
}
}

func (rm *ResponseManager) send(message responseManagerMessage, done <-chan struct{}) {
func (rm *ResponseManager) send(message responseManagerMessage, done <-chan struct{}) error {
// prioritize cancelled context
select {
case <-done:
return errors.New("unable to send message before cancellation")
default:
}
select {
case <-rm.ctx.Done():
return rm.ctx.Err()
case <-done:
return errors.New("unable to send message before cancellation")
case rm.messages <- message:
return nil
}
}

Expand Down

0 comments on commit cb19a45

Please sign in to comment.