From 5f744ee1f7a1e1442f0e1327de30f08449893349 Mon Sep 17 00:00:00 2001 From: Tim Li <47233368+timl3136@users.noreply.github.com> Date: Tue, 7 May 2024 16:14:18 -0700 Subject: [PATCH] Added more unit tests for history/handler.go (#5984) --- service/history/handler/handler_test.go | 442 ++++++++++++++++++++++++ 1 file changed, 442 insertions(+) diff --git a/service/history/handler/handler_test.go b/service/history/handler/handler_test.go index 250dd7842d9..81808478106 100644 --- a/service/history/handler/handler_test.go +++ b/service/history/handler/handler_test.go @@ -26,6 +26,7 @@ import ( "math/rand" "sync/atomic" "testing" + "time" "github.com/golang/mock/gomock" "github.com/stretchr/testify/mock" @@ -1194,6 +1195,447 @@ func (s *handlerSuite) TestDescribeHistoryHost() { } } +func (s *handlerSuite) TestRemoveTask() { + now := time.Now() + testInput := map[string]struct { + request *types.RemoveTaskRequest + expectedError bool + mockFn func() + }{ + "transfer task": { + request: &types.RemoveTaskRequest{ + ShardID: 0, + Type: common.Int32Ptr(int32(common.TaskTypeTransfer)), + TaskID: int64(1), + }, + expectedError: false, + mockFn: func() { + s.mockResource.ExecutionMgr.On("CompleteTransferTask", mock.Anything, &persistence.CompleteTransferTaskRequest{ + TaskID: int64(1), + }).Return(nil).Once() + }, + }, + "timer task": { + request: &types.RemoveTaskRequest{ + ShardID: 0, + Type: common.Int32Ptr(int32(common.TaskTypeTimer)), + TaskID: int64(1), + VisibilityTimestamp: common.Int64Ptr(int64(now.UnixNano())), + }, + expectedError: false, + mockFn: func() { + s.mockResource.ExecutionMgr.On("CompleteTimerTask", mock.Anything, mock.Anything).Return(nil).Once() + }, + }, + "replication task": { + request: &types.RemoveTaskRequest{ + ShardID: 0, + Type: common.Int32Ptr(int32(common.TaskTypeReplication)), + TaskID: int64(1), + }, + expectedError: false, + mockFn: func() { + s.mockResource.ExecutionMgr.On("CompleteReplicationTask", mock.Anything, &persistence.CompleteReplicationTaskRequest{ + TaskID: int64(1), + }).Return(nil).Once() + }, + }, + "cross cluster task": { + request: &types.RemoveTaskRequest{ + ShardID: 0, + Type: common.Int32Ptr(int32(common.TaskTypeCrossCluster)), + TaskID: int64(1), + }, + expectedError: false, + mockFn: func() { + s.mockResource.ExecutionMgr.On("CompleteCrossClusterTask", mock.Anything, &persistence.CompleteCrossClusterTaskRequest{ + TaskID: int64(1), + }).Return(nil).Once() + }, + }, + "invalid": { + request: &types.RemoveTaskRequest{ + ShardID: 0, + Type: common.Int32Ptr(int32(100)), + }, + expectedError: true, + mockFn: func() {}, + }, + } + + for name, input := range testInput { + s.Run(name, func() { + input.mockFn() + err := s.handler.RemoveTask(context.Background(), input.request) + if input.expectedError { + s.Error(err) + } else { + s.NoError(err) + } + }) + + } +} + +func (s *handlerSuite) TestCloseShard() { + request := &types.CloseShardRequest{ + ShardID: 0, + } + + s.mockShardController.EXPECT().RemoveEngineForShard(0).Return().Times(1) + err := s.handler.CloseShard(context.Background(), request) + s.NoError(err) +} + +func (s *handlerSuite) TestResetQueue() { + testInput := map[string]struct { + request *types.ResetQueueRequest + expectedError bool + mockFn func() + }{ + "getEngine error": { + request: &types.ResetQueueRequest{ + ShardID: 0, + }, + expectedError: true, + mockFn: func() { + s.mockShardController.EXPECT().GetEngineForShard(0).Return(nil, errors.New("error")).Times(1) + }, + }, + "transfer task": { + request: &types.ResetQueueRequest{ + ShardID: 0, + Type: common.Int32Ptr(int32(common.TaskTypeTransfer)), + }, + expectedError: false, + mockFn: func() { + s.mockShardController.EXPECT().GetEngineForShard(0).Return(s.mockEngine, nil).Times(1) + s.mockEngine.EXPECT().ResetTransferQueue(gomock.Any(), gomock.Any()).Return(nil).Times(1) + }, + }, + "timer task": { + request: &types.ResetQueueRequest{ + ShardID: 0, + Type: common.Int32Ptr(int32(common.TaskTypeTimer)), + }, + expectedError: false, + mockFn: func() { + s.mockShardController.EXPECT().GetEngineForShard(0).Return(s.mockEngine, nil).Times(1) + s.mockEngine.EXPECT().ResetTimerQueue(gomock.Any(), gomock.Any()).Return(nil).Times(1) + }, + }, + "cros cluster task": { + request: &types.ResetQueueRequest{ + ShardID: 0, + Type: common.Int32Ptr(int32(common.TaskTypeCrossCluster)), + }, + expectedError: false, + mockFn: func() { + s.mockShardController.EXPECT().GetEngineForShard(0).Return(s.mockEngine, nil).Times(1) + s.mockEngine.EXPECT().ResetCrossClusterQueue(gomock.Any(), gomock.Any()).Return(nil).Times(1) + }, + }, + "invalid task": { + request: &types.ResetQueueRequest{ + ShardID: 0, + Type: common.Int32Ptr(int32(100)), + }, + expectedError: true, + mockFn: func() { + s.mockShardController.EXPECT().GetEngineForShard(0).Return(s.mockEngine, nil).Times(1) + }, + }, + } + + for name, input := range testInput { + s.Run(name, func() { + input.mockFn() + err := s.handler.ResetQueue(context.Background(), input.request) + if input.expectedError { + s.Error(err) + } else { + s.NoError(err) + } + }) + + } +} + +func (s *handlerSuite) TestDescribeQueue() { + testInput := map[string]struct { + request *types.DescribeQueueRequest + expectedError bool + mockFn func() + }{ + "getEngine error": { + request: &types.DescribeQueueRequest{ + ShardID: 0, + }, + expectedError: true, + mockFn: func() { + s.mockShardController.EXPECT().GetEngineForShard(0).Return(nil, errors.New("error")).Times(1) + }, + }, + "transfer task": { + request: &types.DescribeQueueRequest{ + ShardID: 0, + Type: common.Int32Ptr(int32(common.TaskTypeTransfer)), + }, + expectedError: false, + mockFn: func() { + s.mockShardController.EXPECT().GetEngineForShard(0).Return(s.mockEngine, nil).Times(1) + s.mockEngine.EXPECT().DescribeTransferQueue(gomock.Any(), gomock.Any()).Return(&types.DescribeQueueResponse{}, nil).Times(1) + }, + }, + "timer task": { + request: &types.DescribeQueueRequest{ + ShardID: 0, + Type: common.Int32Ptr(int32(common.TaskTypeTimer)), + }, + expectedError: false, + mockFn: func() { + s.mockShardController.EXPECT().GetEngineForShard(0).Return(s.mockEngine, nil).Times(1) + s.mockEngine.EXPECT().DescribeTimerQueue(gomock.Any(), gomock.Any()).Return(&types.DescribeQueueResponse{}, nil).Times(1) + }, + }, + "cross cluster task": { + request: &types.DescribeQueueRequest{ + ShardID: 0, + Type: common.Int32Ptr(int32(common.TaskTypeCrossCluster)), + }, + expectedError: false, + mockFn: func() { + s.mockShardController.EXPECT().GetEngineForShard(0).Return(s.mockEngine, nil).Times(1) + s.mockEngine.EXPECT().DescribeCrossClusterQueue(gomock.Any(), gomock.Any()).Return(&types.DescribeQueueResponse{}, nil).Times(1) + }, + }, + "invalid task": { + request: &types.DescribeQueueRequest{ + Type: common.Int32Ptr(int32(100)), + }, + expectedError: true, + mockFn: func() { + s.mockShardController.EXPECT().GetEngineForShard(0).Return(s.mockEngine, nil).Times(1) + }, + }, + } + + for name, input := range testInput { + s.Run(name, func() { + input.mockFn() + resp, err := s.handler.DescribeQueue(context.Background(), input.request) + if input.expectedError { + s.Nil(resp) + s.Error(err) + } else { + s.NotNil(resp) + s.NoError(err) + } + }) + + } +} + +func (s *handlerSuite) TestDescribeMutableState() { + validInput := &types.DescribeMutableStateRequest{ + DomainUUID: testDomainID, + Execution: &types.WorkflowExecution{ + WorkflowID: testWorkflowID, + RunID: testValidUUID, + }, + } + testInput := map[string]struct { + request *types.DescribeMutableStateRequest + expectedError bool + mockFn func() + }{ + "empty domainID": { + request: &types.DescribeMutableStateRequest{ + DomainUUID: "", + }, + expectedError: true, + mockFn: func() {}, + }, + "getEngine error": { + request: validInput, + expectedError: true, + mockFn: func() { + s.mockShardController.EXPECT().GetEngine(testWorkflowID).Return(nil, errors.New("error")).Times(1) + }, + }, + "getMutableState error": { + request: validInput, + expectedError: true, + mockFn: func() { + s.mockShardController.EXPECT().GetEngine(testWorkflowID).Return(s.mockEngine, nil).Times(1) + s.mockEngine.EXPECT().DescribeMutableState(gomock.Any(), validInput).Return(nil, errors.New("error")).Times(1) + }, + }, + "success": { + request: validInput, + expectedError: false, + mockFn: func() { + s.mockShardController.EXPECT().GetEngine(testWorkflowID).Return(s.mockEngine, nil).Times(1) + s.mockEngine.EXPECT().DescribeMutableState(gomock.Any(), validInput).Return(&types.DescribeMutableStateResponse{}, nil).Times(1) + }, + }, + } + + for name, input := range testInput { + s.Run(name, func() { + input.mockFn() + resp, err := s.handler.DescribeMutableState(context.Background(), input.request) + if input.expectedError { + s.Nil(resp) + s.Error(err) + } else { + s.NotNil(resp) + s.NoError(err) + } + }) + } +} + +func (s *handlerSuite) TestGetMutableState() { + validInput := &types.GetMutableStateRequest{ + DomainUUID: testDomainID, + Execution: &types.WorkflowExecution{ + WorkflowID: testWorkflowID, + RunID: testValidUUID, + }, + } + testInput := map[string]struct { + request *types.GetMutableStateRequest + expectedError bool + mockFn func() + }{ + "empty domainID": { + request: &types.GetMutableStateRequest{ + DomainUUID: "", + }, + expectedError: true, + mockFn: func() {}, + }, + "ratelimit": { + request: validInput, + expectedError: true, + mockFn: func() { + s.mockRatelimiter.EXPECT().Allow().Return(false).Times(1) + }, + }, + "getEngine error": { + request: validInput, + expectedError: true, + mockFn: func() { + s.mockRatelimiter.EXPECT().Allow().Return(true).Times(1) + s.mockShardController.EXPECT().GetEngine(testWorkflowID).Return(nil, errors.New("error")).Times(1) + }, + }, + "getMutableState error": { + request: validInput, + expectedError: true, + mockFn: func() { + s.mockRatelimiter.EXPECT().Allow().Return(true).Times(1) + s.mockShardController.EXPECT().GetEngine(testWorkflowID).Return(s.mockEngine, nil).Times(1) + s.mockEngine.EXPECT().GetMutableState(gomock.Any(), validInput).Return(nil, errors.New("error")).Times(1) + }, + }, + "success": { + request: validInput, + expectedError: false, + mockFn: func() { + s.mockRatelimiter.EXPECT().Allow().Return(true).Times(1) + s.mockShardController.EXPECT().GetEngine(testWorkflowID).Return(s.mockEngine, nil).Times(1) + s.mockEngine.EXPECT().GetMutableState(gomock.Any(), validInput).Return(&types.GetMutableStateResponse{}, nil).Times(1) + }, + }, + } + + for name, input := range testInput { + s.Run(name, func() { + input.mockFn() + resp, err := s.handler.GetMutableState(context.Background(), input.request) + if input.expectedError { + s.Nil(resp) + s.Error(err) + } else { + s.NotNil(resp) + s.NoError(err) + } + }) + } +} + +func (s *handlerSuite) TestPollMutableState() { + validInput := &types.PollMutableStateRequest{ + DomainUUID: testDomainID, + Execution: &types.WorkflowExecution{ + WorkflowID: testWorkflowID, + RunID: testValidUUID, + }, + } + testInput := map[string]struct { + request *types.PollMutableStateRequest + expectedError bool + mockFn func() + }{ + "empty domainID": { + request: &types.PollMutableStateRequest{ + DomainUUID: "", + }, + expectedError: true, + mockFn: func() {}, + }, + "ratelimit": { + request: validInput, + expectedError: true, + mockFn: func() { + s.mockRatelimiter.EXPECT().Allow().Return(false).Times(1) + }, + }, + "getEngine error": { + request: validInput, + expectedError: true, + mockFn: func() { + s.mockRatelimiter.EXPECT().Allow().Return(true).Times(1) + s.mockShardController.EXPECT().GetEngine(testWorkflowID).Return(nil, errors.New("error")).Times(1) + }, + }, + "getMutableState error": { + request: validInput, + expectedError: true, + mockFn: func() { + s.mockRatelimiter.EXPECT().Allow().Return(true).Times(1) + s.mockShardController.EXPECT().GetEngine(testWorkflowID).Return(s.mockEngine, nil).Times(1) + s.mockEngine.EXPECT().PollMutableState(gomock.Any(), validInput).Return(nil, errors.New("error")).Times(1) + }, + }, + "success": { + request: validInput, + expectedError: false, + mockFn: func() { + s.mockRatelimiter.EXPECT().Allow().Return(true).Times(1) + s.mockShardController.EXPECT().GetEngine(testWorkflowID).Return(s.mockEngine, nil).Times(1) + s.mockEngine.EXPECT().PollMutableState(gomock.Any(), validInput).Return(&types.PollMutableStateResponse{}, nil).Times(1) + }, + }, + } + + for name, input := range testInput { + s.Run(name, func() { + input.mockFn() + resp, err := s.handler.PollMutableState(context.Background(), input.request) + if input.expectedError { + s.Nil(resp) + s.Error(err) + } else { + s.NotNil(resp) + s.NoError(err) + } + }) + } +} + func (s *handlerSuite) TestGetCrossClusterTasks() { numShards := 10 targetCluster := cluster.TestAlternativeClusterName