Skip to content

Commit

Permalink
Merge pull request #5022 from tonistiigi/flightcontrol-cachedgroup
Browse files Browse the repository at this point in the history
flightcontrol: add cachedgroup struct
  • Loading branch information
AkihiroSuda authored Jun 19, 2024
2 parents 5f130fa + 9f66e2a commit 01d7739
Show file tree
Hide file tree
Showing 5 changed files with 186 additions and 59 deletions.
62 changes: 16 additions & 46 deletions client/llb/async.go
Original file line number Diff line number Diff line change
Expand Up @@ -6,76 +6,46 @@ import (
"github.com/moby/buildkit/solver/pb"
"github.com/moby/buildkit/util/flightcontrol"
digest "github.com/opencontainers/go-digest"
"github.com/pkg/errors"
)

type asyncState struct {
f func(context.Context, State, *Constraints) (State, error)
prev State
target State
set bool
err error
g flightcontrol.Group[State]
f func(context.Context, State, *Constraints) (State, error)
prev State
g flightcontrol.CachedGroup[State]
}

func (as *asyncState) Output() Output {
return as
}

func (as *asyncState) Vertex(ctx context.Context, c *Constraints) Vertex {
err := as.Do(ctx, c)
target, err := as.Do(ctx, c)
if err != nil {
return &errVertex{err}
}
if as.set {
out := as.target.Output()
if out == nil {
return nil
}
return out.Vertex(ctx, c)
out := target.Output()
if out == nil {
return nil
}
return nil
return out.Vertex(ctx, c)
}

func (as *asyncState) ToInput(ctx context.Context, c *Constraints) (*pb.Input, error) {
err := as.Do(ctx, c)
target, err := as.Do(ctx, c)
if err != nil {
return nil, err
}
if as.set {
out := as.target.Output()
if out == nil {
return nil, nil
}
return out.ToInput(ctx, c)
out := target.Output()
if out == nil {
return nil, nil
}
return nil, nil
return out.ToInput(ctx, c)
}

func (as *asyncState) Do(ctx context.Context, c *Constraints) error {
_, err := as.g.Do(ctx, "", func(ctx context.Context) (State, error) {
if as.set {
return as.target, as.err
}
res, err := as.f(ctx, as.prev, c)
if err != nil {
select {
case <-ctx.Done():
if errors.Is(err, context.Cause(ctx)) {
return res, err
}
default:
}
}
as.target = res
as.err = err
as.set = true
return res, err
func (as *asyncState) Do(ctx context.Context, c *Constraints) (State, error) {
return as.g.Do(ctx, "", func(ctx context.Context) (State, error) {
return as.f(ctx, as.prev, c)
})
if err != nil {
return err
}
return as.err
}

type errVertex struct {
Expand Down
11 changes: 8 additions & 3 deletions client/llb/state.go
Original file line number Diff line number Diff line change
Expand Up @@ -105,11 +105,11 @@ func (s State) getValue(k interface{}) func(context.Context, *Constraints) (inte
}
if s.async != nil {
return func(ctx context.Context, c *Constraints) (interface{}, error) {
err := s.async.Do(ctx, c)
target, err := s.async.Do(ctx, c)
if err != nil {
return nil, err
}
return s.async.target.getValue(k)(ctx, c)
return target.getValue(k)(ctx, c)
}
}
if s.prev == nil {
Expand All @@ -119,8 +119,13 @@ func (s State) getValue(k interface{}) func(context.Context, *Constraints) (inte
}

func (s State) Async(f func(context.Context, State, *Constraints) (State, error)) State {
as := &asyncState{
f: f,
prev: s,
}
as.g.CacheError = true
s2 := State{
async: &asyncState{f: f, prev: s},
async: as,
}
return s2
}
Expand Down
12 changes: 2 additions & 10 deletions frontend/dockerui/config.go
Original file line number Diff line number Diff line change
Expand Up @@ -78,8 +78,7 @@ type Client struct {
Config
client client.Client
ignoreCache []string
bctx *buildContext
g flightcontrol.Group[*buildContext]
g flightcontrol.CachedGroup[*buildContext]
bopts client.BuildOpts

dockerignore []byte
Expand Down Expand Up @@ -288,14 +287,7 @@ func (bc *Client) init() error {

func (bc *Client) buildContext(ctx context.Context) (*buildContext, error) {
return bc.g.Do(ctx, "initcontext", func(ctx context.Context) (*buildContext, error) {
if bc.bctx != nil {
return bc.bctx, nil
}
bctx, err := bc.initContext(ctx)
if err == nil {
bc.bctx = bctx
}
return bctx, err
return bc.initContext(ctx)
})
}

Expand Down
63 changes: 63 additions & 0 deletions util/flightcontrol/cached.go
Original file line number Diff line number Diff line change
@@ -0,0 +1,63 @@
package flightcontrol

import (
"context"
"sync"

"github.com/pkg/errors"
)

// Group is a flightcontrol synchronization group that memoizes the results of a function
// and returns the cached result if the function is called with the same key.
// Don't use with long-running groups as the results are cached indefinitely.
type CachedGroup[T any] struct {
// CacheError defines if error results should also be cached.
// It is not safe to change this value after the first call to Do.
// Context cancellation errors are never cached.
CacheError bool
g Group[T]
mu sync.Mutex
cache map[string]result[T]
}

type result[T any] struct {
v T
err error
}

// Do executes a context function syncronized by the key or returns the cached result for the key.
func (g *CachedGroup[T]) Do(ctx context.Context, key string, fn func(ctx context.Context) (T, error)) (T, error) {
return g.g.Do(ctx, key, func(ctx context.Context) (T, error) {
g.mu.Lock()
if v, ok := g.cache[key]; ok {
g.mu.Unlock()
if v.err != nil {
if g.CacheError {
return v.v, v.err
}
} else {
return v.v, nil
}
}
g.mu.Unlock()
v, err := fn(ctx)
if err != nil {
select {
case <-ctx.Done():
if errors.Is(err, context.Cause(ctx)) {
return v, err
}
default:
}
}
if err == nil || g.CacheError {
g.mu.Lock()
if g.cache == nil {
g.cache = make(map[string]result[T])
}
g.cache[key] = result[T]{v: v, err: err}
g.mu.Unlock()
}
return v, err
})
}
97 changes: 97 additions & 0 deletions util/flightcontrol/cached_test.go
Original file line number Diff line number Diff line change
@@ -0,0 +1,97 @@
package flightcontrol

import (
"context"
"testing"
"time"

"github.com/pkg/errors"
"github.com/stretchr/testify/require"
)

func TestCached(t *testing.T) {
var g CachedGroup[int]

ctx := context.TODO()

v, err := g.Do(ctx, "11", func(ctx context.Context) (int, error) {
return 1, nil
})
require.NoError(t, err)
require.Equal(t, 1, v)

v, err = g.Do(ctx, "22", func(ctx context.Context) (int, error) {
return 2, nil
})
require.NoError(t, err)
require.Equal(t, 2, v)

didCall := false
v, err = g.Do(ctx, "11", func(ctx context.Context) (int, error) {
didCall = true
return 3, nil
})
require.NoError(t, err)
require.Equal(t, 1, v)
require.Equal(t, false, didCall)

// by default, errors are not cached
_, err = g.Do(ctx, "33", func(ctx context.Context) (int, error) {
return 0, errors.Errorf("some error")
})

require.Error(t, err)
require.ErrorContains(t, err, "some error")

v, err = g.Do(ctx, "33", func(ctx context.Context) (int, error) {
return 3, nil
})

require.NoError(t, err)
require.Equal(t, 3, v)
}

func TestCachedError(t *testing.T) {
var g CachedGroup[string]
g.CacheError = true

ctx := context.TODO()

_, err := g.Do(ctx, "11", func(ctx context.Context) (string, error) {
return "", errors.Errorf("first error")
})
require.Error(t, err)
require.ErrorContains(t, err, "first error")

_, err = g.Do(ctx, "11", func(ctx context.Context) (string, error) {
return "never-ran", nil
})
require.Error(t, err)
require.ErrorContains(t, err, "first error")

// context errors are never cached
ctx, cancel := context.WithTimeoutCause(context.TODO(), 10*time.Millisecond, nil)
defer cancel()
_, err = g.Do(ctx, "22", func(ctx context.Context) (string, error) {
select {
case <-ctx.Done():
return "", context.Cause(ctx)
case <-time.After(10 * time.Second):
return "", errors.Errorf("unexpected error")
}
})
require.Error(t, err)
require.ErrorContains(t, err, "context deadline exceeded")

select {
case <-ctx.Done():
default:
require.Fail(t, "expected context to be done")
}

v, err := g.Do(ctx, "22", func(ctx context.Context) (string, error) {
return "did-run", nil
})
require.NoError(t, err)
require.Equal(t, "did-run", v)
}

0 comments on commit 01d7739

Please sign in to comment.