diff --git a/common/persistence/sql/sql_execution_store_test.go b/common/persistence/sql/sql_execution_store_test.go index e9c58d9c22f..a720d962047 100644 --- a/common/persistence/sql/sql_execution_store_test.go +++ b/common/persistence/sql/sql_execution_store_test.go @@ -32,9 +32,11 @@ import ( "github.com/stretchr/testify/require" "github.com/uber/cadence/common" + "github.com/uber/cadence/common/log/testlogger" "github.com/uber/cadence/common/persistence" "github.com/uber/cadence/common/persistence/serialization" "github.com/uber/cadence/common/persistence/sql/sqlplugin" + "github.com/uber/cadence/common/types" ) func TestDeleteCurrentWorkflowExecution(t *testing.T) { @@ -1883,3 +1885,73 @@ func TestDeleteWorkflowExecution(t *testing.T) { }) } } + +func TestTxExecuteShardLocked(t *testing.T) { + tests := []struct { + name string + mockSetup func(*sqlplugin.MockDB, *sqlplugin.MockTx) + operation string + rangeID int64 + fn func(sqlplugin.Tx) error + wantError error + }{ + { + name: "Success", + mockSetup: func(mockDB *sqlplugin.MockDB, mockTx *sqlplugin.MockTx) { + mockDB.EXPECT().BeginTx(gomock.Any(), gomock.Any()).Return(mockTx, nil) + mockTx.EXPECT().ReadLockShards(gomock.Any(), gomock.Any()).Return(11, nil) + mockTx.EXPECT().Commit().Return(nil) + }, + operation: "Insert", + rangeID: 11, + fn: func(sqlplugin.Tx) error { return nil }, + wantError: nil, + }, + { + name: "Error", + mockSetup: func(mockDB *sqlplugin.MockDB, mockTx *sqlplugin.MockTx) { + mockDB.EXPECT().BeginTx(gomock.Any(), gomock.Any()).Return(mockTx, nil) + mockTx.EXPECT().ReadLockShards(gomock.Any(), gomock.Any()).Return(11, nil) + mockTx.EXPECT().Rollback().Return(nil) + mockDB.EXPECT().IsNotFoundError(gomock.Any()).Return(false) + mockDB.EXPECT().IsTimeoutError(gomock.Any()).Return(false) + mockDB.EXPECT().IsThrottlingError(gomock.Any()).Return(false) + }, + operation: "Insert", + rangeID: 11, + fn: func(sqlplugin.Tx) error { return errors.New("error") }, + wantError: &types.InternalServiceError{Message: "Insert operation failed. Error: error"}, + }, + { + name: "Error - shard ownership lost", + mockSetup: func(mockDB *sqlplugin.MockDB, mockTx *sqlplugin.MockTx) { + mockDB.EXPECT().BeginTx(gomock.Any(), gomock.Any()).Return(mockTx, nil) + mockTx.EXPECT().ReadLockShards(gomock.Any(), gomock.Any()).Return(12, nil) + mockTx.EXPECT().Rollback().Return(nil) + }, + operation: "Insert", + rangeID: 11, + fn: func(sqlplugin.Tx) error { return errors.New("error") }, + wantError: &persistence.ShardOwnershipLostError{ShardID: 0, Msg: "Failed to lock shard. Previous range ID: 11; new range ID: 12"}, + }, + } + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + ctrl := gomock.NewController(t) + mockDB := sqlplugin.NewMockDB(ctrl) + mockTx := sqlplugin.NewMockTx(ctrl) + tt.mockSetup(mockDB, mockTx) + + s := &sqlExecutionStore{ + shardID: 0, + sqlStore: sqlStore{ + db: mockDB, + logger: testlogger.New(t), + }, + } + + gotError := s.txExecuteShardLocked(context.Background(), 0, tt.operation, tt.rangeID, tt.fn) + assert.Equal(t, tt.wantError, gotError) + }) + } +}