From efef2e00f3c6668519edaecff9cd18899f971aa4 Mon Sep 17 00:00:00 2001 From: Ilya Ozherelyev Date: Fri, 29 Dec 2023 22:44:49 +0100 Subject: [PATCH] Added tests for visibility sampling wrapper (#5564) * Simplify sampled visibility manager --- common/persistence/client/factory.go | 18 +- .../visibilitySamplingClient_test.go | 12 +- .../wrappers/sampled/tokenbucketfactory.go | 75 +++++ .../sampled/tokenbucketfactory_test.go | 42 +++ .../sampled/visibility_manager.go} | 222 ++++++--------- .../sampled/visibility_manager_test.go | 265 ++++++++++++++++++ 6 files changed, 497 insertions(+), 137 deletions(-) create mode 100644 common/persistence/wrappers/sampled/tokenbucketfactory.go create mode 100644 common/persistence/wrappers/sampled/tokenbucketfactory_test.go rename common/persistence/{visibilitySamplingClient.go => wrappers/sampled/visibility_manager.go} (53%) create mode 100644 common/persistence/wrappers/sampled/visibility_manager_test.go diff --git a/common/persistence/client/factory.go b/common/persistence/client/factory.go index 9148916ac21..240d043f938 100644 --- a/common/persistence/client/factory.go +++ b/common/persistence/client/factory.go @@ -24,6 +24,7 @@ import ( "sync" "github.com/uber/cadence/common" + "github.com/uber/cadence/common/clock" "github.com/uber/cadence/common/config" es "github.com/uber/cadence/common/elasticsearch" "github.com/uber/cadence/common/log" @@ -38,6 +39,7 @@ import ( "github.com/uber/cadence/common/persistence/sql" "github.com/uber/cadence/common/persistence/wrappers/errorinjectors" "github.com/uber/cadence/common/persistence/wrappers/ratelimited" + "github.com/uber/cadence/common/persistence/wrappers/sampled" pnt "github.com/uber/cadence/common/pinot" "github.com/uber/cadence/common/quotas" "github.com/uber/cadence/common/service" @@ -399,11 +401,17 @@ func (f *factoryImpl) newDBVisibilityManager( result = ratelimited.NewVisibilityManager(result, ds.ratelimit) } if visibilityConfig.EnableDBVisibilitySampling != nil && visibilityConfig.EnableDBVisibilitySampling() { - result = p.NewVisibilitySamplingClient(result, &p.SamplingConfig{ - VisibilityClosedMaxQPS: visibilityConfig.WriteDBVisibilityClosedMaxQPS, - VisibilityListMaxQPS: visibilityConfig.DBVisibilityListMaxQPS, - VisibilityOpenMaxQPS: visibilityConfig.WriteDBVisibilityOpenMaxQPS, - }, f.metricsClient, f.logger) + result = sampled.NewVisibilityManager(result, sampled.Params{ + Config: &sampled.Config{ + VisibilityClosedMaxQPS: visibilityConfig.WriteDBVisibilityClosedMaxQPS, + VisibilityListMaxQPS: visibilityConfig.DBVisibilityListMaxQPS, + VisibilityOpenMaxQPS: visibilityConfig.WriteDBVisibilityOpenMaxQPS, + }, + MetricClient: f.metricsClient, + Logger: f.logger, + TimeSource: clock.NewRealTimeSource(), + RateLimiterFactoryFunc: sampled.NewDomainToBucketMap, + }) } if f.metricsClient != nil { result = p.NewVisibilityPersistenceMetricsClient(result, f.metricsClient, f.logger, f.config) diff --git a/common/persistence/persistence-tests/visibilitySamplingClient_test.go b/common/persistence/persistence-tests/visibilitySamplingClient_test.go index 5afedad55c8..baa97dc7687 100644 --- a/common/persistence/persistence-tests/visibilitySamplingClient_test.go +++ b/common/persistence/persistence-tests/visibilitySamplingClient_test.go @@ -29,12 +29,14 @@ import ( "github.com/stretchr/testify/require" "github.com/stretchr/testify/suite" + "github.com/uber/cadence/common/clock" "github.com/uber/cadence/common/dynamicconfig" "github.com/uber/cadence/common/log/testlogger" "github.com/uber/cadence/common/metrics" mmocks "github.com/uber/cadence/common/metrics/mocks" "github.com/uber/cadence/common/mocks" p "github.com/uber/cadence/common/persistence" + "github.com/uber/cadence/common/persistence/wrappers/sampled" "github.com/uber/cadence/common/types" ) @@ -66,13 +68,19 @@ func (s *VisibilitySamplingSuite) SetupTest() { s.Assertions = require.New(s.T()) // Have to define our overridden assertions in the test setup. If we did it earlier, s.T() will return nil s.persistence = &mocks.VisibilityManager{} - config := &p.SamplingConfig{ + config := &sampled.Config{ VisibilityOpenMaxQPS: dynamicconfig.GetIntPropertyFilteredByDomain(1), VisibilityClosedMaxQPS: dynamicconfig.GetIntPropertyFilteredByDomain(10), VisibilityListMaxQPS: dynamicconfig.GetIntPropertyFilteredByDomain(1), } s.metricClient = &mmocks.Client{} - s.client = p.NewVisibilitySamplingClient(s.persistence, config, s.metricClient, testlogger.New(s.T())) + s.client = sampled.NewVisibilityManager(s.persistence, sampled.Params{ + Config: config, + MetricClient: s.metricClient, + Logger: testlogger.New(s.T()), + TimeSource: clock.NewRealTimeSource(), + RateLimiterFactoryFunc: sampled.NewDomainToBucketMap, + }) } func (s *VisibilitySamplingSuite) TearDownTest() { diff --git a/common/persistence/wrappers/sampled/tokenbucketfactory.go b/common/persistence/wrappers/sampled/tokenbucketfactory.go new file mode 100644 index 00000000000..5202bd48723 --- /dev/null +++ b/common/persistence/wrappers/sampled/tokenbucketfactory.go @@ -0,0 +1,75 @@ +// The MIT License (MIT) + +// Copyright (c) 2017-2020 Uber Technologies Inc. + +// Permission is hereby granted, free of charge, to any person obtaining a copy +// of this software and associated documentation files (the "Software"), to deal +// in the Software without restriction, including without limitation the rights +// to use, copy, modify, merge, publish, distribute, sublicense, and/or sell +// copies of the Software, and to permit persons to whom the Software is +// furnished to do so, subject to the following conditions: +// +// The above copyright notice and this permission notice shall be included in all +// copies or substantial portions of the Software. +// +// THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR +// IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, +// FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE +// AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER +// LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, +// OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE +// SOFTWARE. + +package sampled + +import ( + "sync" + + "github.com/uber/cadence/common/clock" + "github.com/uber/cadence/common/dynamicconfig" + "github.com/uber/cadence/common/tokenbucket" +) + +type RateLimiterFactoryFunc func(timeSource clock.TimeSource, numOfPriority int, qpsConfig dynamicconfig.IntPropertyFnWithDomainFilter) RateLimiterFactory + +type RateLimiterFactory interface { + GetRateLimiter(domain string) tokenbucket.PriorityTokenBucket +} + +type domainToBucketMap struct { + sync.RWMutex + timeSource clock.TimeSource + qpsConfig dynamicconfig.IntPropertyFnWithDomainFilter + numOfPriority int + mappings map[string]tokenbucket.PriorityTokenBucket +} + +// NewDomainToBucketMap returns a rate limiter factory. +func NewDomainToBucketMap(timeSource clock.TimeSource, numOfPriority int, qpsConfig dynamicconfig.IntPropertyFnWithDomainFilter) RateLimiterFactory { + return &domainToBucketMap{ + timeSource: timeSource, + qpsConfig: qpsConfig, + numOfPriority: numOfPriority, + mappings: make(map[string]tokenbucket.PriorityTokenBucket), + } +} + +func (m *domainToBucketMap) GetRateLimiter(domain string) tokenbucket.PriorityTokenBucket { + m.RLock() + rateLimiter, exist := m.mappings[domain] + m.RUnlock() + + if exist { + return rateLimiter + } + + m.Lock() + if rateLimiter, ok := m.mappings[domain]; ok { // read again to ensure no duplicate create + m.Unlock() + return rateLimiter + } + rateLimiter = tokenbucket.NewFullPriorityTokenBucket(m.numOfPriority, m.qpsConfig(domain), m.timeSource) + m.mappings[domain] = rateLimiter + m.Unlock() + return rateLimiter +} diff --git a/common/persistence/wrappers/sampled/tokenbucketfactory_test.go b/common/persistence/wrappers/sampled/tokenbucketfactory_test.go new file mode 100644 index 00000000000..6e769acf301 --- /dev/null +++ b/common/persistence/wrappers/sampled/tokenbucketfactory_test.go @@ -0,0 +1,42 @@ +// The MIT License (MIT) + +// Copyright (c) 2017-2020 Uber Technologies Inc. + +// Permission is hereby granted, free of charge, to any person obtaining a copy +// of this software and associated documentation files (the "Software"), to deal +// in the Software without restriction, including without limitation the rights +// to use, copy, modify, merge, publish, distribute, sublicense, and/or sell +// copies of the Software, and to permit persons to whom the Software is +// furnished to do so, subject to the following conditions: +// +// The above copyright notice and this permission notice shall be included in all +// copies or substantial portions of the Software. +// +// THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR +// IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, +// FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE +// AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER +// LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, +// OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE +// SOFTWARE. + +package sampled + +import ( + "testing" + + "github.com/stretchr/testify/assert" + + "github.com/uber/cadence/common/clock" + "github.com/uber/cadence/common/dynamicconfig" +) + +func TestDomainToBucketMap(t *testing.T) { + mockedTime := clock.NewMockedTimeSource() + factory := NewDomainToBucketMap(mockedTime, 1, dynamicconfig.GetIntPropertyFilteredByDomain(1)) + + // Test that the factory returns the same bucket for the same domain + bucket1 := factory.GetRateLimiter("domain1") + bucket2 := factory.GetRateLimiter("domain1") + assert.Equal(t, bucket1, bucket2, "domain bucket should return the same bucket for the same domain") +} diff --git a/common/persistence/visibilitySamplingClient.go b/common/persistence/wrappers/sampled/visibility_manager.go similarity index 53% rename from common/persistence/visibilitySamplingClient.go rename to common/persistence/wrappers/sampled/visibility_manager.go index 68c96471009..395b688bf72 100644 --- a/common/persistence/visibilitySamplingClient.go +++ b/common/persistence/wrappers/sampled/visibility_manager.go @@ -18,19 +18,17 @@ // OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN // THE SOFTWARE. -package persistence +package sampled import ( "context" - "runtime" - "sync" "github.com/uber/cadence/common/clock" "github.com/uber/cadence/common/dynamicconfig" "github.com/uber/cadence/common/log" "github.com/uber/cadence/common/log/tag" "github.com/uber/cadence/common/metrics" - "github.com/uber/cadence/common/tokenbucket" + "github.com/uber/cadence/common/persistence" "github.com/uber/cadence/common/types" ) @@ -44,21 +42,18 @@ const ( // errPersistenceLimitExceededForList is the error indicating QPS limit reached for list visibility. var errPersistenceLimitExceededForList = &types.ServiceBusyError{Message: "Persistence Max QPS Reached for List Operations."} -type visibilitySamplingClient struct { - rateLimitersForOpen *domainToBucketMap - rateLimitersForClosed *domainToBucketMap - rateLimitersForList *domainToBucketMap - persistence VisibilityManager - config *SamplingConfig +type visibilityManager struct { + rateLimitersForOpen RateLimiterFactory + rateLimitersForClosed RateLimiterFactory + rateLimitersForList RateLimiterFactory + persistence persistence.VisibilityManager metricClient metrics.Client logger log.Logger } -var _ VisibilityManager = (*visibilitySamplingClient)(nil) - type ( - // SamplingConfig is config for visibility - SamplingConfig struct { + // Config is config for visibility + Config struct { VisibilityOpenMaxQPS dynamicconfig.IntPropertyFnWithDomainFilter `yaml:"-" json:"-"` // VisibilityClosedMaxQPS max QPS for record closed workflows VisibilityClosedMaxQPS dynamicconfig.IntPropertyFnWithDomainFilter `yaml:"-" json:"-"` @@ -67,61 +62,37 @@ type ( } ) -// NewVisibilitySamplingClient creates a client to manage visibility with sampling +type Params struct { + Config *Config + MetricClient metrics.Client + Logger log.Logger + TimeSource clock.TimeSource + RateLimiterFactoryFunc RateLimiterFactoryFunc +} + +// NewVisibilityManager creates a client to manage visibility with sampling // For write requests, it will do sampling which will lose some records // For read requests, it will do sampling which will return service busy errors. // Note that this is different from NewVisibilityPersistenceRateLimitedClient which is overlapping with the read processing. -func NewVisibilitySamplingClient(persistence VisibilityManager, config *SamplingConfig, metricClient metrics.Client, logger log.Logger) VisibilityManager { - return &visibilitySamplingClient{ +func NewVisibilityManager(persistence persistence.VisibilityManager, p Params) persistence.VisibilityManager { + return &visibilityManager{ persistence: persistence, - rateLimitersForOpen: newDomainToBucketMap(), - rateLimitersForClosed: newDomainToBucketMap(), - rateLimitersForList: newDomainToBucketMap(), - config: config, - metricClient: metricClient, - logger: logger, - } -} - -type domainToBucketMap struct { - sync.RWMutex - mappings map[string]tokenbucket.PriorityTokenBucket -} - -func newDomainToBucketMap() *domainToBucketMap { - return &domainToBucketMap{ - mappings: make(map[string]tokenbucket.PriorityTokenBucket), - } -} - -func (m *domainToBucketMap) getRateLimiter(domain string, numOfPriority, qps int) tokenbucket.PriorityTokenBucket { - m.RLock() - rateLimiter, exist := m.mappings[domain] - m.RUnlock() - - if exist { - return rateLimiter - } - - m.Lock() - if rateLimiter, ok := m.mappings[domain]; ok { // read again to ensure no duplicate create - m.Unlock() - return rateLimiter + rateLimitersForOpen: p.RateLimiterFactoryFunc(p.TimeSource, numOfPriorityForOpen, p.Config.VisibilityOpenMaxQPS), + rateLimitersForClosed: p.RateLimiterFactoryFunc(p.TimeSource, numOfPriorityForClosed, p.Config.VisibilityClosedMaxQPS), + rateLimitersForList: p.RateLimiterFactoryFunc(p.TimeSource, numOfPriorityForList, p.Config.VisibilityListMaxQPS), + metricClient: p.MetricClient, + logger: p.Logger, } - rateLimiter = tokenbucket.NewFullPriorityTokenBucket(numOfPriority, qps, clock.NewRealTimeSource()) - m.mappings[domain] = rateLimiter - m.Unlock() - return rateLimiter } -func (p *visibilitySamplingClient) RecordWorkflowExecutionStarted( +func (p *visibilityManager) RecordWorkflowExecutionStarted( ctx context.Context, - request *RecordWorkflowExecutionStartedRequest, + request *persistence.RecordWorkflowExecutionStartedRequest, ) error { domain := request.Domain domainID := request.DomainUUID - rateLimiter := p.rateLimitersForOpen.getRateLimiter(domain, numOfPriorityForOpen, p.config.VisibilityOpenMaxQPS(domain)) + rateLimiter := p.rateLimitersForOpen.GetRateLimiter(domain) if ok, _ := rateLimiter.GetToken(0, 1); ok { return p.persistence.RecordWorkflowExecutionStarted(ctx, request) } @@ -137,15 +108,15 @@ func (p *visibilitySamplingClient) RecordWorkflowExecutionStarted( return nil } -func (p *visibilitySamplingClient) RecordWorkflowExecutionClosed( +func (p *visibilityManager) RecordWorkflowExecutionClosed( ctx context.Context, - request *RecordWorkflowExecutionClosedRequest, + request *persistence.RecordWorkflowExecutionClosedRequest, ) error { domain := request.Domain domainID := request.DomainUUID priority := getRequestPriority(request) - rateLimiter := p.rateLimitersForClosed.getRateLimiter(domain, numOfPriorityForClosed, p.config.VisibilityClosedMaxQPS(domain)) + rateLimiter := p.rateLimitersForClosed.GetRateLimiter(domain) if ok, _ := rateLimiter.GetToken(priority, 1); ok { return p.persistence.RecordWorkflowExecutionClosed(ctx, request) } @@ -161,21 +132,14 @@ func (p *visibilitySamplingClient) RecordWorkflowExecutionClosed( return nil } -func (p *visibilitySamplingClient) RecordWorkflowExecutionUninitialized( +func (p *visibilityManager) UpsertWorkflowExecution( ctx context.Context, - request *RecordWorkflowExecutionUninitializedRequest, -) error { - return p.persistence.RecordWorkflowExecutionUninitialized(ctx, request) -} - -func (p *visibilitySamplingClient) UpsertWorkflowExecution( - ctx context.Context, - request *UpsertWorkflowExecutionRequest, + request *persistence.UpsertWorkflowExecutionRequest, ) error { domain := request.Domain domainID := request.DomainUUID - rateLimiter := p.rateLimitersForClosed.getRateLimiter(domain, numOfPriorityForClosed, p.config.VisibilityClosedMaxQPS(domain)) + rateLimiter := p.rateLimitersForClosed.GetRateLimiter(domain) if ok, _ := rateLimiter.GetToken(0, 1); ok { return p.persistence.UpsertWorkflowExecution(ctx, request) } @@ -191,134 +155,141 @@ func (p *visibilitySamplingClient) UpsertWorkflowExecution( return nil } -func (p *visibilitySamplingClient) ListOpenWorkflowExecutions( +func (p *visibilityManager) ListOpenWorkflowExecutions( ctx context.Context, - request *ListWorkflowExecutionsRequest, -) (*ListWorkflowExecutionsResponse, error) { - if err := p.tryConsumeListToken(request.Domain); err != nil { + request *persistence.ListWorkflowExecutionsRequest, +) (*persistence.ListWorkflowExecutionsResponse, error) { + if err := p.tryConsumeListToken(request.Domain, "ListOpenWorkflowExecutions"); err != nil { return nil, err } return p.persistence.ListOpenWorkflowExecutions(ctx, request) } -func (p *visibilitySamplingClient) ListClosedWorkflowExecutions( +func (p *visibilityManager) ListClosedWorkflowExecutions( ctx context.Context, - request *ListWorkflowExecutionsRequest, -) (*ListWorkflowExecutionsResponse, error) { - if err := p.tryConsumeListToken(request.Domain); err != nil { + request *persistence.ListWorkflowExecutionsRequest, +) (*persistence.ListWorkflowExecutionsResponse, error) { + if err := p.tryConsumeListToken(request.Domain, "ListClosedWorkflowExecutions"); err != nil { return nil, err } return p.persistence.ListClosedWorkflowExecutions(ctx, request) } -func (p *visibilitySamplingClient) ListOpenWorkflowExecutionsByType( +func (p *visibilityManager) ListOpenWorkflowExecutionsByType( ctx context.Context, - request *ListWorkflowExecutionsByTypeRequest, -) (*ListWorkflowExecutionsResponse, error) { - if err := p.tryConsumeListToken(request.Domain); err != nil { + request *persistence.ListWorkflowExecutionsByTypeRequest, +) (*persistence.ListWorkflowExecutionsResponse, error) { + if err := p.tryConsumeListToken(request.Domain, "ListOpenWorkflowExecutionsByType"); err != nil { return nil, err } return p.persistence.ListOpenWorkflowExecutionsByType(ctx, request) } -func (p *visibilitySamplingClient) ListClosedWorkflowExecutionsByType( +func (p *visibilityManager) ListClosedWorkflowExecutionsByType( ctx context.Context, - request *ListWorkflowExecutionsByTypeRequest, -) (*ListWorkflowExecutionsResponse, error) { - if err := p.tryConsumeListToken(request.Domain); err != nil { + request *persistence.ListWorkflowExecutionsByTypeRequest, +) (*persistence.ListWorkflowExecutionsResponse, error) { + if err := p.tryConsumeListToken(request.Domain, "ListClosedWorkflowExecutionsByType"); err != nil { return nil, err } return p.persistence.ListClosedWorkflowExecutionsByType(ctx, request) } -func (p *visibilitySamplingClient) ListOpenWorkflowExecutionsByWorkflowID( +func (p *visibilityManager) ListOpenWorkflowExecutionsByWorkflowID( ctx context.Context, - request *ListWorkflowExecutionsByWorkflowIDRequest, -) (*ListWorkflowExecutionsResponse, error) { - if err := p.tryConsumeListToken(request.Domain); err != nil { + request *persistence.ListWorkflowExecutionsByWorkflowIDRequest, +) (*persistence.ListWorkflowExecutionsResponse, error) { + if err := p.tryConsumeListToken(request.Domain, "ListOpenWorkflowExecutionsByWorkflowID"); err != nil { return nil, err } return p.persistence.ListOpenWorkflowExecutionsByWorkflowID(ctx, request) } -func (p *visibilitySamplingClient) ListClosedWorkflowExecutionsByWorkflowID( +func (p *visibilityManager) ListClosedWorkflowExecutionsByWorkflowID( ctx context.Context, - request *ListWorkflowExecutionsByWorkflowIDRequest, -) (*ListWorkflowExecutionsResponse, error) { - if err := p.tryConsumeListToken(request.Domain); err != nil { + request *persistence.ListWorkflowExecutionsByWorkflowIDRequest, +) (*persistence.ListWorkflowExecutionsResponse, error) { + if err := p.tryConsumeListToken(request.Domain, "ListClosedWorkflowExecutionsByWorkflowID"); err != nil { return nil, err } return p.persistence.ListClosedWorkflowExecutionsByWorkflowID(ctx, request) } -func (p *visibilitySamplingClient) ListClosedWorkflowExecutionsByStatus( +func (p *visibilityManager) ListClosedWorkflowExecutionsByStatus( ctx context.Context, - request *ListClosedWorkflowExecutionsByStatusRequest, -) (*ListWorkflowExecutionsResponse, error) { - if err := p.tryConsumeListToken(request.Domain); err != nil { + request *persistence.ListClosedWorkflowExecutionsByStatusRequest, +) (*persistence.ListWorkflowExecutionsResponse, error) { + if err := p.tryConsumeListToken(request.Domain, "ListClosedWorkflowExecutionsByStatus"); err != nil { return nil, err } return p.persistence.ListClosedWorkflowExecutionsByStatus(ctx, request) } -func (p *visibilitySamplingClient) GetClosedWorkflowExecution( +func (p *visibilityManager) RecordWorkflowExecutionUninitialized( ctx context.Context, - request *GetClosedWorkflowExecutionRequest, -) (*GetClosedWorkflowExecutionResponse, error) { + request *persistence.RecordWorkflowExecutionUninitializedRequest, +) error { + return p.persistence.RecordWorkflowExecutionUninitialized(ctx, request) +} + +func (p *visibilityManager) GetClosedWorkflowExecution( + ctx context.Context, + request *persistence.GetClosedWorkflowExecutionRequest, +) (*persistence.GetClosedWorkflowExecutionResponse, error) { return p.persistence.GetClosedWorkflowExecution(ctx, request) } -func (p *visibilitySamplingClient) DeleteWorkflowExecution( +func (p *visibilityManager) DeleteWorkflowExecution( ctx context.Context, - request *VisibilityDeleteWorkflowExecutionRequest, + request *persistence.VisibilityDeleteWorkflowExecutionRequest, ) error { return p.persistence.DeleteWorkflowExecution(ctx, request) } -func (p *visibilitySamplingClient) DeleteUninitializedWorkflowExecution( +func (p *visibilityManager) DeleteUninitializedWorkflowExecution( ctx context.Context, - request *VisibilityDeleteWorkflowExecutionRequest, + request *persistence.VisibilityDeleteWorkflowExecutionRequest, ) error { return p.persistence.DeleteUninitializedWorkflowExecution(ctx, request) } -func (p *visibilitySamplingClient) ListWorkflowExecutions( +func (p *visibilityManager) ListWorkflowExecutions( ctx context.Context, - request *ListWorkflowExecutionsByQueryRequest, -) (*ListWorkflowExecutionsResponse, error) { + request *persistence.ListWorkflowExecutionsByQueryRequest, +) (*persistence.ListWorkflowExecutionsResponse, error) { return p.persistence.ListWorkflowExecutions(ctx, request) } -func (p *visibilitySamplingClient) ScanWorkflowExecutions( +func (p *visibilityManager) ScanWorkflowExecutions( ctx context.Context, - request *ListWorkflowExecutionsByQueryRequest, -) (*ListWorkflowExecutionsResponse, error) { + request *persistence.ListWorkflowExecutionsByQueryRequest, +) (*persistence.ListWorkflowExecutionsResponse, error) { return p.persistence.ScanWorkflowExecutions(ctx, request) } -func (p *visibilitySamplingClient) CountWorkflowExecutions( +func (p *visibilityManager) CountWorkflowExecutions( ctx context.Context, - request *CountWorkflowExecutionsRequest, -) (*CountWorkflowExecutionsResponse, error) { + request *persistence.CountWorkflowExecutionsRequest, +) (*persistence.CountWorkflowExecutionsResponse, error) { return p.persistence.CountWorkflowExecutions(ctx, request) } -func (p *visibilitySamplingClient) Close() { +func (p *visibilityManager) Close() { p.persistence.Close() } -func (p *visibilitySamplingClient) GetName() string { +func (p *visibilityManager) GetName() string { return p.persistence.GetName() } -func getRequestPriority(request *RecordWorkflowExecutionClosedRequest) int { +func getRequestPriority(request *persistence.RecordWorkflowExecutionClosedRequest) int { priority := 0 if request.Status == types.WorkflowExecutionCloseStatusCompleted { priority = 1 // low priority for completed workflows @@ -326,22 +297,13 @@ func getRequestPriority(request *RecordWorkflowExecutionClosedRequest) int { return priority } -func (p *visibilitySamplingClient) tryConsumeListToken(domain string) error { - rateLimiter := p.rateLimitersForList.getRateLimiter(domain, numOfPriorityForList, p.config.VisibilityListMaxQPS(domain)) +func (p *visibilityManager) tryConsumeListToken(domain, method string) error { + rateLimiter := p.rateLimitersForList.GetRateLimiter(domain) ok, _ := rateLimiter.GetToken(0, 1) if ok { - p.logger.Debug("List API request consumed QPS token", tag.WorkflowDomainName(domain), tag.Name(callerFuncName(2))) + p.logger.Debug("List API request consumed QPS token", tag.WorkflowDomainName(domain), tag.Name(method)) return nil } - p.logger.Debug("List API request is being sampled", tag.WorkflowDomainName(domain), tag.Name(callerFuncName(2))) + p.logger.Debug("List API request is being sampled", tag.WorkflowDomainName(domain), tag.Name(method)) return errPersistenceLimitExceededForList } - -func callerFuncName(skip int) string { - pc, _, _, ok := runtime.Caller(skip) - details := runtime.FuncForPC(pc) - if ok && details != nil { - return details.Name() - } - return "" -} diff --git a/common/persistence/wrappers/sampled/visibility_manager_test.go b/common/persistence/wrappers/sampled/visibility_manager_test.go new file mode 100644 index 00000000000..e510b626e6b --- /dev/null +++ b/common/persistence/wrappers/sampled/visibility_manager_test.go @@ -0,0 +1,265 @@ +// The MIT License (MIT) + +// Copyright (c) 2017-2020 Uber Technologies Inc. + +// Permission is hereby granted, free of charge, to any person obtaining a copy +// of this software and associated documentation files (the "Software"), to deal +// in the Software without restriction, including without limitation the rights +// to use, copy, modify, merge, publish, distribute, sublicense, and/or sell +// copies of the Software, and to permit persons to whom the Software is +// furnished to do so, subject to the following conditions: +// +// The above copyright notice and this permission notice shall be included in all +// copies or substantial portions of the Software. +// +// THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR +// IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, +// FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE +// AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER +// LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, +// OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE +// SOFTWARE. + +package sampled + +import ( + "context" + "testing" + "time" + + "github.com/golang/mock/gomock" + "github.com/stretchr/testify/assert" + + "github.com/uber/cadence/common/clock" + "github.com/uber/cadence/common/dynamicconfig" + "github.com/uber/cadence/common/log/testlogger" + "github.com/uber/cadence/common/metrics" + "github.com/uber/cadence/common/persistence" + "github.com/uber/cadence/common/tokenbucket" + "github.com/uber/cadence/common/types" +) + +func TestVisibilityManagerSampledCalls(t *testing.T) { + for _, tc := range []struct { + name string + priority int + prepareMock func(*persistence.MockVisibilityManager) + operation func(context.Context, string, persistence.VisibilityManager) error + }{ + { + name: "RecordWorkflowExecutionStarted", + prepareMock: func(mock *persistence.MockVisibilityManager) { + mock.EXPECT().RecordWorkflowExecutionStarted(gomock.Any(), gomock.Any()).Return(nil).Times(1) + }, + operation: func(ctx context.Context, domain string, m persistence.VisibilityManager) error { + return m.RecordWorkflowExecutionStarted(ctx, &persistence.RecordWorkflowExecutionStartedRequest{ + Domain: domain, + }) + }, + }, + { + name: "RecordWorkflowExecutionClosed", + prepareMock: func(mock *persistence.MockVisibilityManager) { + mock.EXPECT().RecordWorkflowExecutionClosed(gomock.Any(), gomock.Any()).Return(nil).Times(1) + }, + operation: func(ctx context.Context, domain string, m persistence.VisibilityManager) error { + return m.RecordWorkflowExecutionClosed(ctx, &persistence.RecordWorkflowExecutionClosedRequest{ + Domain: domain, + Status: types.WorkflowExecutionCloseStatusCanceled, + }) + }, + }, + { + name: "RecordWorkflowExecutionClosed_Completed", + priority: 1, + prepareMock: func(mock *persistence.MockVisibilityManager) { + mock.EXPECT().RecordWorkflowExecutionClosed(gomock.Any(), gomock.Any()).Return(nil).Times(1) + }, + operation: func(ctx context.Context, domain string, m persistence.VisibilityManager) error { + return m.RecordWorkflowExecutionClosed(ctx, &persistence.RecordWorkflowExecutionClosedRequest{ + Domain: domain, + Status: types.WorkflowExecutionCloseStatusCompleted, + }) + }, + }, + { + name: "UpsertWorkflowExecution", + prepareMock: func(mock *persistence.MockVisibilityManager) { + mock.EXPECT().UpsertWorkflowExecution(gomock.Any(), gomock.Any()).Return(nil).Times(1) + }, + operation: func(ctx context.Context, domain string, m persistence.VisibilityManager) error { + return m.UpsertWorkflowExecution(ctx, &persistence.UpsertWorkflowExecutionRequest{ + Domain: domain, + }) + }, + }, + } { + t.Run(tc.name, func(t *testing.T) { + ctrl := gomock.NewController(t) + mockedManager := persistence.NewMockVisibilityManager(ctrl) + + testDomain := "domain1" + + m := NewVisibilityManager(mockedManager, Params{ + Config: &Config{}, + MetricClient: metrics.NewNoopMetricsClient(), + Logger: testlogger.New(t), + TimeSource: clock.NewMockedTimeSource(), + RateLimiterFactoryFunc: rateLimiterStubFunc(map[string]tokenbucket.PriorityTokenBucket{ + testDomain: &tokenBucketFactoryStub{tokens: map[int]int{tc.priority: 1}}, + }), + }) + + tc.prepareMock(mockedManager) + + err := tc.operation(context.Background(), testDomain, m) + assert.NoError(t, err, "first call should succeed") + + err = tc.operation(context.Background(), testDomain, m) + assert.NoError(t, err, "second call should not fail, but underlying call should be blocked by rate limiter") + }) + } +} + +func TestVisibilityManagerListOperations(t *testing.T) { + for _, tc := range []struct { + name string + priority int + prepareMock func(*persistence.MockVisibilityManager) + operation func(context.Context, string, persistence.VisibilityManager) error + }{ + { + name: "ListOpenWorkflowExecutions", + prepareMock: func(mock *persistence.MockVisibilityManager) { + mock.EXPECT().ListOpenWorkflowExecutions(gomock.Any(), gomock.Any()).Return(&persistence.ListWorkflowExecutionsResponse{}, nil).Times(1) + }, + operation: func(ctx context.Context, domain string, m persistence.VisibilityManager) error { + _, err := m.ListOpenWorkflowExecutions(ctx, &persistence.ListWorkflowExecutionsRequest{ + Domain: domain, + }) + return err + }, + }, + { + name: "ListClosedWorkflowExecutions", + prepareMock: func(mock *persistence.MockVisibilityManager) { + mock.EXPECT().ListClosedWorkflowExecutions(gomock.Any(), gomock.Any()).Return(&persistence.ListWorkflowExecutionsResponse{}, nil).Times(1) + }, + operation: func(ctx context.Context, domain string, m persistence.VisibilityManager) error { + _, err := m.ListClosedWorkflowExecutions(ctx, &persistence.ListWorkflowExecutionsRequest{ + Domain: domain, + }) + return err + }, + }, + { + name: "ListOpenWorkflowExecutionsByType", + prepareMock: func(mock *persistence.MockVisibilityManager) { + mock.EXPECT().ListOpenWorkflowExecutionsByType(gomock.Any(), gomock.Any()).Return(&persistence.ListWorkflowExecutionsResponse{}, nil).Times(1) + }, + operation: func(ctx context.Context, domain string, m persistence.VisibilityManager) error { + _, err := m.ListOpenWorkflowExecutionsByType(ctx, &persistence.ListWorkflowExecutionsByTypeRequest{ + ListWorkflowExecutionsRequest: persistence.ListWorkflowExecutionsRequest{ + Domain: domain, + }, + }) + return err + }, + }, + { + name: "ListClosedWorkflowExecutionsByType", + prepareMock: func(mock *persistence.MockVisibilityManager) { + mock.EXPECT().ListClosedWorkflowExecutionsByType(gomock.Any(), gomock.Any()).Return(&persistence.ListWorkflowExecutionsResponse{}, nil).Times(1) + }, + operation: func(ctx context.Context, domain string, m persistence.VisibilityManager) error { + _, err := m.ListClosedWorkflowExecutionsByType(ctx, &persistence.ListWorkflowExecutionsByTypeRequest{ + ListWorkflowExecutionsRequest: persistence.ListWorkflowExecutionsRequest{ + Domain: domain, + }, + }) + return err + }, + }, + { + name: "ListOpenWorkflowExecutionsByWorkflowID", + prepareMock: func(mock *persistence.MockVisibilityManager) { + mock.EXPECT().ListOpenWorkflowExecutionsByWorkflowID(gomock.Any(), gomock.Any()).Return(&persistence.ListWorkflowExecutionsResponse{}, nil).Times(1) + }, + operation: func(ctx context.Context, domain string, m persistence.VisibilityManager) error { + _, err := m.ListOpenWorkflowExecutionsByWorkflowID(ctx, &persistence.ListWorkflowExecutionsByWorkflowIDRequest{ + ListWorkflowExecutionsRequest: persistence.ListWorkflowExecutionsRequest{ + Domain: domain, + }, + }) + return err + }, + }, + { + name: "ListClosedWorkflowExecutionsByWorkflowID", + prepareMock: func(mock *persistence.MockVisibilityManager) { + mock.EXPECT().ListClosedWorkflowExecutionsByWorkflowID(gomock.Any(), gomock.Any()).Return(&persistence.ListWorkflowExecutionsResponse{}, nil).Times(1) + }, + operation: func(ctx context.Context, domain string, m persistence.VisibilityManager) error { + _, err := m.ListClosedWorkflowExecutionsByWorkflowID(ctx, &persistence.ListWorkflowExecutionsByWorkflowIDRequest{ + ListWorkflowExecutionsRequest: persistence.ListWorkflowExecutionsRequest{ + Domain: domain, + }, + }) + return err + }, + }, + } { + t.Run(tc.name, func(t *testing.T) { + ctrl := gomock.NewController(t) + mockedManager := persistence.NewMockVisibilityManager(ctrl) + + testDomain := "domain1" + + m := NewVisibilityManager(mockedManager, Params{ + Config: &Config{}, + MetricClient: metrics.NewNoopMetricsClient(), + Logger: testlogger.New(t), + TimeSource: clock.NewMockedTimeSource(), + RateLimiterFactoryFunc: rateLimiterStubFunc(map[string]tokenbucket.PriorityTokenBucket{ + testDomain: &tokenBucketFactoryStub{tokens: map[int]int{tc.priority: 1}}, + }), + }) + + tc.prepareMock(mockedManager) + + err := tc.operation(context.Background(), testDomain, m) + assert.NoError(t, err, "first call should succeed") + + err = tc.operation(context.Background(), testDomain, m) + assert.Error(t, err, "second call should fail since underlying call should be blocked by rate limiter") + }) + } +} + +func rateLimiterStubFunc(domainData map[string]tokenbucket.PriorityTokenBucket) RateLimiterFactoryFunc { + return func(timeSource clock.TimeSource, numOfPriority int, qpsConfig dynamicconfig.IntPropertyFnWithDomainFilter) RateLimiterFactory { + return rateLimiterStub{domainData} + } +} + +type rateLimiterStub struct { + data map[string]tokenbucket.PriorityTokenBucket +} + +func (r rateLimiterStub) GetRateLimiter(domain string) tokenbucket.PriorityTokenBucket { + return r.data[domain] +} + +type tokenBucketFactoryStub struct { + tokens map[int]int +} + +func (t *tokenBucketFactoryStub) GetToken(priority, count int) (bool, time.Duration) { + val := t.tokens[priority] + if count > val { + return false, time.Duration(0) + } + val -= count + t.tokens[priority] = val + return true, time.Duration(0) +}