Skip to content

Commit

Permalink
Make cache generic to avoid casting (#3179)
Browse files Browse the repository at this point in the history
* Make cache generic to avoid casting

Signed-off-by: Sylvain Rabot <[email protected]>

* Update handler/handler.go

---------

Signed-off-by: Sylvain Rabot <[email protected]>
Co-authored-by: Steve Coffman <[email protected]>
  • Loading branch information
sylr and StevenACoffman authored Jul 12, 2024
1 parent f2cf11e commit 4d8d93c
Show file tree
Hide file tree
Showing 13 changed files with 66 additions and 61 deletions.
20 changes: 11 additions & 9 deletions graphql/cache.go
Original file line number Diff line number Diff line change
Expand Up @@ -3,27 +3,29 @@ package graphql
import "context"

// Cache is a shared store for APQ and query AST caching
type Cache interface {
type Cache[T any] interface {
// Get looks up a key's value from the cache.
Get(ctx context.Context, key string) (value any, ok bool)
Get(ctx context.Context, key string) (value T, ok bool)

// Add adds a value to the cache.
Add(ctx context.Context, key string, value any)
Add(ctx context.Context, key string, value T)
}

// MapCache is the simplest implementation of a cache, because it can not evict it should only be used in tests
type MapCache map[string]any
type MapCache[T any] map[string]T

// Get looks up a key's value from the cache.
func (m MapCache) Get(_ context.Context, key string) (value any, ok bool) {
func (m MapCache[T]) Get(_ context.Context, key string) (value T, ok bool) {
v, ok := m[key]
return v, ok
}

// Add adds a value to the cache.
func (m MapCache) Add(_ context.Context, key string, value any) { m[key] = value }
func (m MapCache[T]) Add(_ context.Context, key string, value T) { m[key] = value }

type NoCache struct{}
type NoCache[T any, T2 *T] struct{}

func (n NoCache) Get(_ context.Context, _ string) (value any, ok bool) { return nil, false }
func (n NoCache) Add(_ context.Context, _ string, _ any) {}
var _ Cache[*string] = (*NoCache[string, *string])(nil)

func (n NoCache[T, T2]) Get(_ context.Context, _ string) (value T2, ok bool) { return nil, false }
func (n NoCache[T, T2]) Add(_ context.Context, _ string, _ T2) {}
20 changes: 10 additions & 10 deletions graphql/cache_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -9,7 +9,7 @@ import (

func TestMapCache(t *testing.T) {
t.Run("Add and Get", func(t *testing.T) {
cache := MapCache{}
cache := MapCache[string]{}
ctx := context.Background()
key := "testKey"
value := "testValue"
Expand All @@ -29,7 +29,7 @@ func TestMapCache(t *testing.T) {

func TestMapCacheMultipleEntries(t *testing.T) {
t.Run("Multiple Add and Get", func(t *testing.T) {
cache := MapCache{}
cache := MapCache[string]{}
ctx := context.Background()

// Define multiple key-value pairs
Expand Down Expand Up @@ -93,7 +93,7 @@ func TestMapCacheEdgeCases(t *testing.T) {

for _, tc := range tests {
t.Run(tc.name, func(t *testing.T) {
cache := MapCache{}
cache := MapCache[string]{}
ctx := context.Background()

// Set initial value if needed
Expand All @@ -114,13 +114,13 @@ func TestMapCacheEdgeCases(t *testing.T) {

func TestNoCache(t *testing.T) {
t.Run("Add and Get", func(t *testing.T) {
cache := NoCache{}
cache := NoCache[string, *string]{}
ctx := context.Background()
key := "testKey"
value := "testValue"

// Test Add
cache.Add(ctx, key, value) // Should do nothing
cache.Add(ctx, key, &value) // Should do nothing

// Test Get
gotValue, ok := cache.Get(ctx, key)
Expand All @@ -131,7 +131,7 @@ func TestNoCache(t *testing.T) {

func TestNoCacheMultipleEntries(t *testing.T) {
t.Run("Multiple Add and Get", func(t *testing.T) {
cache := NoCache{}
cache := NoCache[string, *string]{}
ctx := context.Background()

// Define multiple key-value pairs
Expand All @@ -143,7 +143,7 @@ func TestNoCacheMultipleEntries(t *testing.T) {

// Test Add for multiple entries
for key, value := range entries {
cache.Add(ctx, key, value) // Should do nothing
cache.Add(ctx, key, &value) // Should do nothing
}

// Test Get for multiple entries
Expand All @@ -161,7 +161,7 @@ func TestNoCacheEdgeCases(t *testing.T) {
key string
value string
wantOk bool
wantValue any
wantValue *string
}

tests := []testCase{
Expand Down Expand Up @@ -190,11 +190,11 @@ func TestNoCacheEdgeCases(t *testing.T) {

for _, tc := range tests {
t.Run(tc.name, func(t *testing.T) {
cache := NoCache{}
cache := NoCache[string, *string]{}
ctx := context.Background()

// Test Add
cache.Add(ctx, tc.key, tc.value)
cache.Add(ctx, tc.key, &tc.value)

// Test Get
gotValue, ok := cache.Get(ctx, tc.key)
Expand Down
8 changes: 4 additions & 4 deletions graphql/executor/executor.go
Original file line number Diff line number Diff line change
Expand Up @@ -22,7 +22,7 @@ type Executor struct {

errorPresenter graphql.ErrorPresenterFunc
recoverFunc graphql.RecoverFunc
queryCache graphql.Cache
queryCache graphql.Cache[*ast.QueryDocument]

parserTokenLimit int
}
Expand All @@ -36,7 +36,7 @@ func New(es graphql.ExecutableSchema) *Executor {
es: es,
errorPresenter: graphql.DefaultErrorPresenter,
recoverFunc: graphql.DefaultRecover,
queryCache: graphql.NoCache{},
queryCache: graphql.NoCache[ast.QueryDocument, *ast.QueryDocument]{},
ext: processExtensions(nil),
parserTokenLimit: parserTokenNoLimit,
}
Expand Down Expand Up @@ -161,7 +161,7 @@ func (e *Executor) PresentRecoveredError(ctx context.Context, err any) error {
return e.errorPresenter(ctx, e.recoverFunc(ctx, err))
}

func (e *Executor) SetQueryCache(cache graphql.Cache) {
func (e *Executor) SetQueryCache(cache graphql.Cache[*ast.QueryDocument]) {
e.queryCache = cache
}

Expand Down Expand Up @@ -194,7 +194,7 @@ func (e *Executor) parseQuery(

stats.Parsing.End = now
stats.Validation.Start = now
return doc.(*ast.QueryDocument), nil
return doc, nil
}

doc, err := parser.ParseQueryWithTokenLimit(&ast.Source{Input: query}, e.parserTokenLimit)
Expand Down
6 changes: 3 additions & 3 deletions graphql/executor/executor_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -141,7 +141,7 @@ func TestExecutor(t *testing.T) {

t.Run("query caching", func(t *testing.T) {
ctx := context.Background()
cache := &graphql.MapCache{}
cache := &graphql.MapCache[*ast.QueryDocument]{}
exec.SetQueryCache(cache)
qry := `query Foo {name}`

Expand All @@ -151,7 +151,7 @@ func TestExecutor(t *testing.T) {

cacheDoc, ok := cache.Get(ctx, qry)
require.True(t, ok)
require.Equal(t, "Foo", cacheDoc.(*ast.QueryDocument).Operations[0].Name)
require.Equal(t, "Foo", cacheDoc.Operations[0].Name)
})

t.Run("cache hits use document from cache", func(t *testing.T) {
Expand All @@ -164,7 +164,7 @@ func TestExecutor(t *testing.T) {

cacheDoc, ok := cache.Get(ctx, qry)
require.True(t, ok)
require.Equal(t, "Bar", cacheDoc.(*ast.QueryDocument).Operations[0].Name)
require.Equal(t, "Bar", cacheDoc.Operations[0].Name)
})
})
}
Expand Down
4 changes: 2 additions & 2 deletions graphql/handler/apollofederatedtracingv1/tracing_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -101,7 +101,7 @@ func TestApolloTracing_Concurrent(t *testing.T) {
func TestApolloTracing_withFail(t *testing.T) {
h := testserver.New()
h.AddTransport(transport.POST{})
h.Use(extension.AutomaticPersistedQuery{Cache: lru.New(100)})
h.Use(extension.AutomaticPersistedQuery{Cache: lru.New[string](100)})
h.Use(&apollofederatedtracingv1.Tracer{})

resp := doRequest(h, http.MethodPost, "/graphql", `{"operationName":"A","extensions":{"persistedQuery":{"version":1,"sha256Hash":"338bbc16ac780daf81845339fbf0342061c1e9d2b702c96d3958a13a557083a6"}}}`)
Expand All @@ -124,7 +124,7 @@ func TestApolloTracing_withFail(t *testing.T) {
func TestApolloTracing_withMissingOp(t *testing.T) {
h := testserver.New()
h.AddTransport(transport.POST{})
h.Use(extension.AutomaticPersistedQuery{Cache: lru.New(100)})
h.Use(extension.AutomaticPersistedQuery{Cache: lru.New[string](100)})
h.Use(&apollofederatedtracingv1.Tracer{})

resp := doRequest(h, http.MethodPost, "/graphql", `{}`)
Expand Down
2 changes: 1 addition & 1 deletion graphql/handler/apollotracing/tracer_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -85,7 +85,7 @@ func TestApolloTracing_withFail(t *testing.T) {

h := testserver.New()
h.AddTransport(transport.POST{})
h.Use(extension.AutomaticPersistedQuery{Cache: lru.New(100)})
h.Use(extension.AutomaticPersistedQuery{Cache: lru.New[string](100)})
h.Use(apollotracing.Tracer{})

resp := doRequest(h, http.MethodPost, "/graphql", `{"operationName":"A","extensions":{"persistedQuery":{"version":1,"sha256Hash":"338bbc16ac780daf81845339fbf0342061c1e9d2b702c96d3958a13a557083a6"}}}`)
Expand Down
6 changes: 3 additions & 3 deletions graphql/handler/extension/apq.go
Original file line number Diff line number Diff line change
Expand Up @@ -23,7 +23,7 @@ const (
// hash in the next request.
// see https:/apollographql/apollo-link-persisted-queries
type AutomaticPersistedQuery struct {
Cache graphql.Cache
Cache graphql.Cache[string]
}

type ApqStats struct {
Expand Down Expand Up @@ -72,14 +72,14 @@ func (a AutomaticPersistedQuery) MutateOperationParameters(ctx context.Context,

fullQuery := false
if rawParams.Query == "" {
var ok bool
// client sent optimistic query hash without query string, get it from the cache
query, ok := a.Cache.Get(ctx, extension.Sha256)
rawParams.Query, ok = a.Cache.Get(ctx, extension.Sha256)
if !ok {
err := gqlerror.Errorf(errPersistedQueryNotFound)
errcode.Set(err, errPersistedQueryNotFoundCode)
return err
}
rawParams.Query = query.(string)
} else {
// client sent optimistic query hash with query string, verify and store it
if computeQueryHash(rawParams.Query) != extension.Sha256 {
Expand Down
18 changes: 9 additions & 9 deletions graphql/handler/extension/apq_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -16,7 +16,7 @@ import (

func TestAPQIntegration(t *testing.T) {
h := testserver.New()
h.Use(&extension.AutomaticPersistedQuery{Cache: graphql.MapCache{}})
h.Use(&extension.AutomaticPersistedQuery{Cache: graphql.MapCache[string]{}})
h.AddTransport(&transport.POST{})

var stats *extension.ApqStats
Expand Down Expand Up @@ -44,7 +44,7 @@ func TestAPQ(t *testing.T) {
Query: "original query",
}

err := extension.AutomaticPersistedQuery{Cache: graphql.MapCache{}}.MutateOperationParameters(ctx, params)
err := extension.AutomaticPersistedQuery{Cache: graphql.MapCache[string]{}}.MutateOperationParameters(ctx, params)

require.Equal(t, (*gqlerror.Error)(nil), err)
require.Equal(t, "original query", params.Query)
Expand All @@ -61,7 +61,7 @@ func TestAPQ(t *testing.T) {
},
}

err := extension.AutomaticPersistedQuery{Cache: graphql.MapCache{}}.MutateOperationParameters(ctx, params)
err := extension.AutomaticPersistedQuery{Cache: graphql.MapCache[string]{}}.MutateOperationParameters(ctx, params)
require.Equal(t, "PersistedQueryNotFound", err.Message)
})

Expand All @@ -76,7 +76,7 @@ func TestAPQ(t *testing.T) {
},
},
}
cache := graphql.MapCache{}
cache := graphql.MapCache[string]{}
err := extension.AutomaticPersistedQuery{Cache: cache}.MutateOperationParameters(ctx, params)

require.Equal(t, (*gqlerror.Error)(nil), err)
Expand All @@ -95,7 +95,7 @@ func TestAPQ(t *testing.T) {
},
},
}
cache := graphql.MapCache{}
cache := graphql.MapCache[string]{}
err := extension.AutomaticPersistedQuery{cache}.MutateOperationParameters(ctx, params)
require.Equal(t, (*gqlerror.Error)(nil), err)

Expand All @@ -113,7 +113,7 @@ func TestAPQ(t *testing.T) {
},
},
}
cache := graphql.MapCache{
cache := graphql.MapCache[string]{
hash: query,
}
err := extension.AutomaticPersistedQuery{cache}.MutateOperationParameters(ctx, params)
Expand All @@ -130,7 +130,7 @@ func TestAPQ(t *testing.T) {
},
}

err := extension.AutomaticPersistedQuery{graphql.MapCache{}}.MutateOperationParameters(ctx, params)
err := extension.AutomaticPersistedQuery{graphql.MapCache[string]{}}.MutateOperationParameters(ctx, params)
require.Equal(t, "invalid APQ extension data", err.Message)
})

Expand All @@ -143,7 +143,7 @@ func TestAPQ(t *testing.T) {
},
},
}
err := extension.AutomaticPersistedQuery{graphql.MapCache{}}.MutateOperationParameters(ctx, params)
err := extension.AutomaticPersistedQuery{graphql.MapCache[string]{}}.MutateOperationParameters(ctx, params)
require.Equal(t, "unsupported APQ version", err.Message)
})

Expand All @@ -159,7 +159,7 @@ func TestAPQ(t *testing.T) {
},
}

err := extension.AutomaticPersistedQuery{graphql.MapCache{}}.MutateOperationParameters(ctx, params)
err := extension.AutomaticPersistedQuery{graphql.MapCache[string]{}}.MutateOperationParameters(ctx, params)
require.Equal(t, "provided APQ hash does not match query", err.Message)
})
}
Expand Down
16 changes: 8 additions & 8 deletions graphql/handler/lru/lru.go
Original file line number Diff line number Diff line change
Expand Up @@ -8,26 +8,26 @@ import (
"github.com/99designs/gqlgen/graphql"
)

type LRU struct {
lru *lru.Cache[string, any]
type LRU[T any] struct {
lru *lru.Cache[string, T]
}

var _ graphql.Cache = &LRU{}
var _ graphql.Cache[any] = &LRU[any]{}

func New(size int) *LRU {
cache, err := lru.New[string, any](size)
func New[T any](size int) *LRU[T] {
cache, err := lru.New[string, T](size)
if err != nil {
// An error is only returned for non-positive cache size
// and we already checked for that.
panic("unexpected error creating cache: " + err.Error())
}
return &LRU{cache}
return &LRU[T]{cache}
}

func (l LRU) Get(ctx context.Context, key string) (value any, ok bool) {
func (l LRU[T]) Get(ctx context.Context, key string) (value T, ok bool) {
return l.lru.Get(key)
}

func (l LRU) Add(ctx context.Context, key string, value any) {
func (l LRU[T]) Add(ctx context.Context, key string, value T) {
l.lru.Add(key, value)
}
7 changes: 4 additions & 3 deletions graphql/handler/server.go
Original file line number Diff line number Diff line change
Expand Up @@ -8,6 +8,7 @@ import (
"net/http"
"time"

"github.com/vektah/gqlparser/v2/ast"
"github.com/vektah/gqlparser/v2/gqlerror"

"github.com/99designs/gqlgen/graphql"
Expand Down Expand Up @@ -41,11 +42,11 @@ func NewDefaultServer(es graphql.ExecutableSchema) *Server {
srv.AddTransport(transport.POST{})
srv.AddTransport(transport.MultipartForm{})

srv.SetQueryCache(lru.New(1000))
srv.SetQueryCache(lru.New[*ast.QueryDocument](1000))

srv.Use(extension.Introspection{})
srv.Use(extension.AutomaticPersistedQuery{
Cache: lru.New(100),
Cache: lru.New[string](100),
})

return srv
Expand All @@ -63,7 +64,7 @@ func (s *Server) SetRecoverFunc(f graphql.RecoverFunc) {
s.exec.SetRecoverFunc(f)
}

func (s *Server) SetQueryCache(cache graphql.Cache) {
func (s *Server) SetQueryCache(cache graphql.Cache[*ast.QueryDocument]) {
s.exec.SetQueryCache(cache)
}

Expand Down
Loading

0 comments on commit 4d8d93c

Please sign in to comment.