diff --git a/client/llb/async.go b/client/llb/async.go index cadbb5ef363e..97f0f980af91 100644 --- a/client/llb/async.go +++ b/client/llb/async.go @@ -6,16 +6,12 @@ import ( "github.com/moby/buildkit/solver/pb" "github.com/moby/buildkit/util/flightcontrol" digest "github.com/opencontainers/go-digest" - "github.com/pkg/errors" ) type asyncState struct { - f func(context.Context, State, *Constraints) (State, error) - prev State - target State - set bool - err error - g flightcontrol.Group[State] + f func(context.Context, State, *Constraints) (State, error) + prev State + g flightcontrol.CachedGroup[State] } func (as *asyncState) Output() Output { @@ -23,59 +19,33 @@ func (as *asyncState) Output() Output { } func (as *asyncState) Vertex(ctx context.Context, c *Constraints) Vertex { - err := as.Do(ctx, c) + target, err := as.Do(ctx, c) if err != nil { return &errVertex{err} } - if as.set { - out := as.target.Output() - if out == nil { - return nil - } - return out.Vertex(ctx, c) + out := target.Output() + if out == nil { + return nil } - return nil + return out.Vertex(ctx, c) } func (as *asyncState) ToInput(ctx context.Context, c *Constraints) (*pb.Input, error) { - err := as.Do(ctx, c) + target, err := as.Do(ctx, c) if err != nil { return nil, err } - if as.set { - out := as.target.Output() - if out == nil { - return nil, nil - } - return out.ToInput(ctx, c) + out := target.Output() + if out == nil { + return nil, nil } - return nil, nil + return out.ToInput(ctx, c) } -func (as *asyncState) Do(ctx context.Context, c *Constraints) error { - _, err := as.g.Do(ctx, "", func(ctx context.Context) (State, error) { - if as.set { - return as.target, as.err - } - res, err := as.f(ctx, as.prev, c) - if err != nil { - select { - case <-ctx.Done(): - if errors.Is(err, context.Cause(ctx)) { - return res, err - } - default: - } - } - as.target = res - as.err = err - as.set = true - return res, err +func (as *asyncState) Do(ctx context.Context, c *Constraints) (State, error) { + return as.g.Do(ctx, "", func(ctx context.Context) (State, error) { + return as.f(ctx, as.prev, c) }) - if err != nil { - return err - } - return as.err } type errVertex struct { diff --git a/client/llb/state.go b/client/llb/state.go index 1637e41770d6..056df02730d1 100644 --- a/client/llb/state.go +++ b/client/llb/state.go @@ -104,11 +104,11 @@ func (s State) getValue(k interface{}) func(context.Context, *Constraints) (inte } if s.async != nil { return func(ctx context.Context, c *Constraints) (interface{}, error) { - err := s.async.Do(ctx, c) + target, err := s.async.Do(ctx, c) if err != nil { return nil, err } - return s.async.target.getValue(k)(ctx, c) + return target.getValue(k)(ctx, c) } } if s.prev == nil { @@ -118,8 +118,13 @@ func (s State) getValue(k interface{}) func(context.Context, *Constraints) (inte } func (s State) Async(f func(context.Context, State, *Constraints) (State, error)) State { + as := &asyncState{ + f: f, + prev: s, + } + as.g.CacheError = true s2 := State{ - async: &asyncState{f: f, prev: s}, + async: as, } return s2 } diff --git a/frontend/dockerui/config.go b/frontend/dockerui/config.go index d99b3affa35f..7dc210592408 100644 --- a/frontend/dockerui/config.go +++ b/frontend/dockerui/config.go @@ -78,8 +78,7 @@ type Client struct { Config client client.Client ignoreCache []string - bctx *buildContext - g flightcontrol.Group[*buildContext] + g flightcontrol.CachedGroup[*buildContext] bopts client.BuildOpts dockerignore []byte @@ -288,14 +287,7 @@ func (bc *Client) init() error { func (bc *Client) buildContext(ctx context.Context) (*buildContext, error) { return bc.g.Do(ctx, "initcontext", func(ctx context.Context) (*buildContext, error) { - if bc.bctx != nil { - return bc.bctx, nil - } - bctx, err := bc.initContext(ctx) - if err == nil { - bc.bctx = bctx - } - return bctx, err + return bc.initContext(ctx) }) } diff --git a/util/flightcontrol/cached.go b/util/flightcontrol/cached.go new file mode 100644 index 000000000000..aeaace7514ff --- /dev/null +++ b/util/flightcontrol/cached.go @@ -0,0 +1,63 @@ +package flightcontrol + +import ( + "context" + "sync" + + "github.com/pkg/errors" +) + +// Group is a flightcontrol synchronization group that memoizes the results of a function +// and returns the cached result if the function is called with the same key. +// Don't use with long-running groups as the results are cached indefinitely. +type CachedGroup[T any] struct { + // CacheError defines if error results should also be cached. + // It is not safe to change this value after the first call to Do. + // Context cancellation errors are never cached. + CacheError bool + g Group[T] + mu sync.Mutex + cache map[string]result[T] +} + +type result[T any] struct { + v T + err error +} + +// Do executes a context function syncronized by the key or returns the cached result for the key. +func (g *CachedGroup[T]) Do(ctx context.Context, key string, fn func(ctx context.Context) (T, error)) (T, error) { + return g.g.Do(ctx, key, func(ctx context.Context) (T, error) { + g.mu.Lock() + if v, ok := g.cache[key]; ok { + g.mu.Unlock() + if v.err != nil { + if g.CacheError { + return v.v, v.err + } + } else { + return v.v, nil + } + } + g.mu.Unlock() + v, err := fn(ctx) + if err != nil { + select { + case <-ctx.Done(): + if errors.Is(err, context.Cause(ctx)) { + return v, err + } + default: + } + } + if err == nil || g.CacheError { + g.mu.Lock() + if g.cache == nil { + g.cache = make(map[string]result[T]) + } + g.cache[key] = result[T]{v: v, err: err} + g.mu.Unlock() + } + return v, err + }) +} diff --git a/util/flightcontrol/cached_test.go b/util/flightcontrol/cached_test.go new file mode 100644 index 000000000000..9ccce55b35b5 --- /dev/null +++ b/util/flightcontrol/cached_test.go @@ -0,0 +1,97 @@ +package flightcontrol + +import ( + "context" + "testing" + "time" + + "github.com/pkg/errors" + "github.com/stretchr/testify/require" +) + +func TestCached(t *testing.T) { + var g CachedGroup[int] + + ctx := context.TODO() + + v, err := g.Do(ctx, "11", func(ctx context.Context) (int, error) { + return 1, nil + }) + require.NoError(t, err) + require.Equal(t, 1, v) + + v, err = g.Do(ctx, "22", func(ctx context.Context) (int, error) { + return 2, nil + }) + require.NoError(t, err) + require.Equal(t, 2, v) + + didCall := false + v, err = g.Do(ctx, "11", func(ctx context.Context) (int, error) { + didCall = true + return 3, nil + }) + require.NoError(t, err) + require.Equal(t, 1, v) + require.Equal(t, false, didCall) + + // by default, errors are not cached + _, err = g.Do(ctx, "33", func(ctx context.Context) (int, error) { + return 0, errors.Errorf("some error") + }) + + require.Error(t, err) + require.ErrorContains(t, err, "some error") + + v, err = g.Do(ctx, "33", func(ctx context.Context) (int, error) { + return 3, nil + }) + + require.NoError(t, err) + require.Equal(t, 3, v) +} + +func TestCachedError(t *testing.T) { + var g CachedGroup[string] + g.CacheError = true + + ctx := context.TODO() + + _, err := g.Do(ctx, "11", func(ctx context.Context) (string, error) { + return "", errors.Errorf("first error") + }) + require.Error(t, err) + require.ErrorContains(t, err, "first error") + + _, err = g.Do(ctx, "11", func(ctx context.Context) (string, error) { + return "never-ran", nil + }) + require.Error(t, err) + require.ErrorContains(t, err, "first error") + + // context errors are never cached + ctx, cancel := context.WithTimeoutCause(context.TODO(), 10*time.Millisecond, nil) + defer cancel() + _, err = g.Do(ctx, "22", func(ctx context.Context) (string, error) { + select { + case <-ctx.Done(): + return "", context.Cause(ctx) + case <-time.After(10 * time.Second): + return "", errors.Errorf("unexpected error") + } + }) + require.Error(t, err) + require.ErrorContains(t, err, "context deadline exceeded") + + select { + case <-ctx.Done(): + default: + require.Fail(t, "expected context to be done") + } + + v, err := g.Do(ctx, "22", func(ctx context.Context) (string, error) { + return "did-run", nil + }) + require.NoError(t, err) + require.Equal(t, "did-run", v) +}