Skip to content

Commit

Permalink
contenthash: unify "follow" and trailing-symlink handling for Checksum
Browse files Browse the repository at this point in the history
This patch is part of a series which fixes the symlink resolution
semantics within BuildKit.

Previously, the concept of the follow flag had different meanings in
various parts of the checksum codepath. FollowLinks=false is effectively
O_NOFOLLOW, but the implementation in getFollowLinks was actually more
like RESOLVE_NO_SYMLINKS. This was masked by the fact that
checksumFollow would implement the O_NOFOLLOW behaviour (incorrectly),
but checksumFollow would call checksumNoFollow (which would follow
symlinks in path components by setting follow=true for getFollowLinks).

It is much easier to simply remove these layers of indirection and unify
the meaning of FollowLinks across all of the code. This means that the
old follow flag is no longer needed.

This also means that we can now remove the incorrect symlink resolution
logic in (*cacheContext).checksumFollow() and move the followTrailing
logic to (*cacheContext).checksum(), as well as removing
getFollowParentLinks(). Since this removes some redundant re-checksum
loops, we need to add followTrailing logic to scanPath() so that final
symlink components result in the correct directory being scanned
properly.

The only user of (*cacheContext).checksum(follow=false) was
(*cacheContext).includedPaths() which appeared to be simply using this
as an optimisation (since the path being walked already had its parent
path resolved). Having two easily-confused boolean flags for an
optimisation that is probably not needed (getFollowLinks already does
a fast check to see if the original path is in the cache) seemed
unwarranted, so just keep followTrailing.

Signed-off-by: Aleksa Sarai <[email protected]>
  • Loading branch information
cyphar committed May 2, 2024
1 parent 44b36df commit 6ef5a15
Show file tree
Hide file tree
Showing 2 changed files with 64 additions and 104 deletions.
162 changes: 59 additions & 103 deletions cache/contenthash/checksum.go
Original file line number Diff line number Diff line change
Expand Up @@ -407,7 +407,7 @@ func (cc *cacheContext) Checksum(ctx context.Context, mountable cache.Mountable,
defer m.clean()

if !opts.Wildcard && len(opts.IncludePatterns) == 0 && len(opts.ExcludePatterns) == 0 {
return cc.checksumFollow(ctx, m, p, opts.FollowLinks)
return cc.lazyChecksum(ctx, m, p, opts.FollowLinks)
}

includedPaths, err := cc.includedPaths(ctx, m, p, opts)
Expand All @@ -418,7 +418,7 @@ func (cc *cacheContext) Checksum(ctx context.Context, mountable cache.Mountable,
if opts.FollowLinks {
for i, w := range includedPaths {
if w.record.Type == CacheRecordTypeSymlink {
dgst, err := cc.checksumFollow(ctx, m, w.path, opts.FollowLinks)
dgst, err := cc.lazyChecksum(ctx, m, w.path, opts.FollowLinks)
if err != nil {
return "", err
}
Expand All @@ -445,30 +445,6 @@ func (cc *cacheContext) Checksum(ctx context.Context, mountable cache.Mountable,
return digester.Digest(), nil
}

// checksumFollow returns the digest recorded for p. When follow is set and the
// record found for p is a symlink, the link target is looked up instead, and
// this repeats until a non-symlink record is reached. Relative link targets are
// joined onto the directory of the path that contained the link. An error is
// returned if any lookup fails or if more than maxSymlinkLimit links have been
// walked.
func (cc *cacheContext) checksumFollow(ctx context.Context, m *mount, p string, follow bool) (digest.Digest, error) {
	const maxSymlinkLimit = 255

	for hops := 0; hops <= maxSymlinkLimit; hops++ {
		rec, err := cc.checksumNoFollow(ctx, m, p)
		if err != nil {
			return "", err
		}

		// Anything that is not a symlink (or any record at all when we were
		// asked not to follow) terminates the walk.
		if !follow || rec.Type != CacheRecordTypeSymlink {
			return rec.Digest, nil
		}

		target := rec.Linkname
		if !path.IsAbs(target) {
			target = path.Join(path.Dir(p), target)
		}
		p = target
	}

	return "", errors.Errorf("too many symlinks: %s", p)
}

func (cc *cacheContext) includedPaths(ctx context.Context, m *mount, p string, opts ChecksumOpts) ([]*includedPath, error) {
cc.mu.Lock()
defer cc.mu.Unlock()
Expand All @@ -478,12 +454,12 @@ func (cc *cacheContext) includedPaths(ctx context.Context, m *mount, p string, o
}

root := cc.tree.Root()
scan, err := cc.needsScan(root, "")
scan, err := cc.needsScan(root, "", false)
if err != nil {
return nil, err
}
if scan {
if err := cc.scanPath(ctx, m, ""); err != nil {
if err := cc.scanPath(ctx, m, "", false); err != nil {
return nil, err
}
}
Expand Down Expand Up @@ -542,7 +518,7 @@ func (cc *cacheContext) includedPaths(ctx context.Context, m *mount, p string, o
// involves a symlink. That will match fsutil behavior of
// calling functions such as stat and walk.
var cr *CacheRecord
k, cr, err = getFollowParentLinks(root, k, true)
k, cr, err = getFollowLinks(root, k, false)
if err != nil {
return nil, err
}
Expand Down Expand Up @@ -753,11 +729,7 @@ func wildcardPrefix(root *iradix.Node, p string) (string, []byte, bool, error) {

// Only resolve the final symlink component if there are components in the
// wildcard segment.
resolveFn := getFollowParentLinks
if d2 != "" {
resolveFn = getFollowLinks
}
k, cr, err := resolveFn(root, convertPathToKey([]byte(d1)), true)
k, cr, err := getFollowLinks(root, convertPathToKey([]byte(d1)), d2 != "")
if err != nil {
return "", k, false, err
}
Expand Down Expand Up @@ -796,19 +768,22 @@ func containsWildcards(name string) bool {
return false
}

func (cc *cacheContext) checksumNoFollow(ctx context.Context, m *mount, p string) (*CacheRecord, error) {
func (cc *cacheContext) lazyChecksum(ctx context.Context, m *mount, p string, followTrailing bool) (digest.Digest, error) {
p = keyPath(p)
k := convertPathToKey([]byte(p))

// Try to look up the path directly without doing a scan.
cc.mu.RLock()
if cc.txn == nil {
root := cc.tree.Root()
cc.mu.RUnlock()
v, ok := root.Get(convertPathToKey([]byte(p)))
if ok {
cr := v.(*CacheRecord)
if cr.Digest != "" {
return cr, nil
}

_, cr, err := getFollowLinks(root, k, followTrailing)
if err != nil {
return "", err
}
if cr != nil && cr.Digest != "" {
return cr.Digest, nil
}
} else {
cc.mu.RUnlock()
Expand All @@ -828,7 +803,11 @@ func (cc *cacheContext) checksumNoFollow(ctx context.Context, m *mount, p string
}
}()

return cc.lazyChecksum(ctx, m, p)
cr, err := cc.scanChecksum(ctx, m, p, followTrailing)
if err != nil {
return "", err
}
return cr.Digest, nil
}

func (cc *cacheContext) commitActiveTransaction() {
Expand All @@ -847,21 +826,21 @@ func (cc *cacheContext) commitActiveTransaction() {
cc.txn = nil
}

func (cc *cacheContext) lazyChecksum(ctx context.Context, m *mount, p string) (*CacheRecord, error) {
func (cc *cacheContext) scanChecksum(ctx context.Context, m *mount, p string, followTrailing bool) (*CacheRecord, error) {
root := cc.tree.Root()
scan, err := cc.needsScan(root, p)
scan, err := cc.needsScan(root, p, followTrailing)
if err != nil {
return nil, err
}
if scan {
if err := cc.scanPath(ctx, m, p); err != nil {
if err := cc.scanPath(ctx, m, p, followTrailing); err != nil {
return nil, err
}
}
k := convertPathToKey([]byte(p))
txn := cc.tree.Txn()
root = txn.Root()
cr, updated, err := cc.checksum(ctx, root, txn, m, k, true)
cr, updated, err := cc.checksum(ctx, root, txn, m, k, followTrailing)
if err != nil {
return nil, err
}
Expand All @@ -870,9 +849,9 @@ func (cc *cacheContext) lazyChecksum(ctx context.Context, m *mount, p string) (*
return cr, err
}

func (cc *cacheContext) checksum(ctx context.Context, root *iradix.Node, txn *iradix.Txn, m *mount, k []byte, follow bool) (*CacheRecord, bool, error) {
func (cc *cacheContext) checksum(ctx context.Context, root *iradix.Node, txn *iradix.Txn, m *mount, k []byte, followTrailing bool) (*CacheRecord, bool, error) {
origk := k
k, cr, err := getFollowParentLinks(root, k, follow)
k, cr, err := getFollowLinks(root, k, followTrailing)
if err != nil {
return nil, false, err
}
Expand All @@ -898,7 +877,9 @@ func (cc *cacheContext) checksum(ctx context.Context, root *iradix.Node, txn *ir
}
h.Write(bytes.TrimPrefix(subk, k))

subcr, _, err := cc.checksum(ctx, root, txn, m, subk, true)
// We do not follow trailing links when checksumming a directory's
// contents.
subcr, _, err := cc.checksum(ctx, root, txn, m, subk, false)
if err != nil {
return nil, false, err
}
Expand Down Expand Up @@ -949,13 +930,13 @@ func (cc *cacheContext) checksum(ctx context.Context, root *iradix.Node, txn *ir

// needsScan returns false if path is in the tree or a parent path is in tree
// and subpath is missing.
func (cc *cacheContext) needsScan(root *iradix.Node, path string) (bool, error) {
func (cc *cacheContext) needsScan(root *iradix.Node, path string, followTrailing bool) (bool, error) {
var (
lastGoodPath string
hasParentInTree bool
)
k := convertPathToKey([]byte(path))
_, cr, err := getFollowLinksCallback(root, k, true, func(subpath string, cr *CacheRecord) error {
_, cr, err := getFollowLinksCallback(root, k, followTrailing, func(subpath string, cr *CacheRecord) error {
if cr != nil {
// If the path is not a symlink, then for now we have a parent in
// the tree. Otherwise, we reset hasParentInTree because we
Expand All @@ -981,8 +962,8 @@ func (cc *cacheContext) needsScan(root *iradix.Node, path string) (bool, error)
return cr == nil && !hasParentInTree, nil
}

func (cc *cacheContext) scanPath(ctx context.Context, m *mount, p string) (retErr error) {
d := path.Dir(path.Join("/", p))
func (cc *cacheContext) scanPath(ctx context.Context, m *mount, p string, followTrailing bool) (retErr error) {
p = path.Join("/", p)

mp, err := m.mount(ctx)
if err != nil {
Expand All @@ -992,7 +973,7 @@ func (cc *cacheContext) scanPath(ctx context.Context, m *mount, p string) (retEr
n := cc.tree.Root()
txn := cc.tree.Txn()

parentPath, err := rootPath(mp, filepath.FromSlash(d), func(p, link string) error {
resolvedPath, err := rootPath(mp, filepath.FromSlash(p), followTrailing, func(p, link string) error {
cr := &CacheRecord{
Type: CacheRecordTypeSymlink,
Linkname: filepath.ToSlash(link),
Expand All @@ -1006,7 +987,14 @@ func (cc *cacheContext) scanPath(ctx context.Context, m *mount, p string) (retEr
return err
}

err = filepath.Walk(parentPath, func(itemPath string, fi os.FileInfo, err error) error {
// Scan the parent directory of the path we resolved, unless we're at the
// root (in which case we scan the root).
scanPath := filepath.Dir(resolvedPath)
if !strings.HasPrefix(filepath.ToSlash(scanPath)+"/", filepath.ToSlash(mp)+"/") {
scanPath = resolvedPath
}

err = filepath.Walk(scanPath, func(itemPath string, fi os.FileInfo, err error) error {
if err != nil {
// If the root doesn't exist, ignore the error.
if errors.Is(err, os.ErrNotExist) {
Expand Down Expand Up @@ -1055,48 +1043,27 @@ func (cc *cacheContext) scanPath(ctx context.Context, m *mount, p string) (retEr
return nil
}

// getFollowParentLinks looks up k the way O_PATH|O_NOFOLLOW would: if the
// final component is a symlink, the record for the link itself is returned
// rather than the record for its target.
//
// A direct hit on k is returned immediately. Otherwise, when follow is set and
// k is non-empty, only the parent portion of k is passed through
// getFollowLinks, and the final component is then looked up verbatim under the
// resolved parent. The returned key is the one that was ultimately looked up;
// the returned record is nil if nothing was found there.
func getFollowParentLinks(root *iradix.Node, k []byte, follow bool) ([]byte, *CacheRecord, error) {
	if v, ok := root.Get(k); ok {
		return k, v.(*CacheRecord), nil
	}
	if len(k) == 0 || !follow {
		return k, nil, nil
	}

	// Resolve the parent only; the last component is kept as-is.
	parent, base := splitKey(k)
	resolvedParent, _, err := getFollowLinks(root, parent, follow)
	if err != nil {
		return nil, nil, err
	}

	// Look the final component up directly under the resolved parent.
	k = append(resolvedParent, base...)
	v, ok := root.Get(k)
	if !ok {
		return k, nil, nil
	}
	return k, v.(*CacheRecord), nil
}

// followLinksCallback is called after we try to resolve each element. If the
// path was not found, cr is nil.
type followLinksCallback func(path string, cr *CacheRecord) error

func getFollowLinks(root *iradix.Node, k []byte, follow bool) ([]byte, *CacheRecord, error) {
return getFollowLinksCallback(root, k, follow, nil)
// getFollowLinks looks up the requested key, fully resolving any symlink
// components encountered.
//
// followTrailing indicates whether the *final component* of the path should be
// resolved if it is a symlink (followTrailing=false is effectively
// O_PATH|O_NOFOLLOW). Note that (in contrast to some Linux APIs),
// followTrailing is obeyed even if the key has a trailing slash (though paths
// like "foo/link/." will cause the link to be resolved).
func getFollowLinks(root *iradix.Node, k []byte, followTrailing bool) ([]byte, *CacheRecord, error) {
return getFollowLinksCallback(root, k, followTrailing, nil)
}

func getFollowLinksCallback(root *iradix.Node, k []byte, follow bool, cb followLinksCallback) ([]byte, *CacheRecord, error) {
func getFollowLinksCallback(root *iradix.Node, k []byte, followTrailing bool, cb followLinksCallback) ([]byte, *CacheRecord, error) {
v, ok := root.Get(k)
if ok && v.(*CacheRecord).Type != CacheRecordTypeSymlink {
if ok && (!followTrailing || v.(*CacheRecord).Type != CacheRecordTypeSymlink) {
return k, v.(*CacheRecord), nil
}
if !follow || len(k) == 0 {
if len(k) == 0 {
return k, nil, nil
}

Expand Down Expand Up @@ -1146,6 +1113,10 @@ func getFollowLinksCallback(root *iradix.Node, k []byte, follow bool, cb followL
currentPath = nextPath
continue
}
if !followTrailing && remainingPath == "" {
currentPath = nextPath
break
}

linksWalked++
if linksWalked > maxSymlinkLimit {
Expand Down Expand Up @@ -1232,18 +1203,3 @@ func convertPathToKey(p []byte) []byte {
func convertKeyToPath(p []byte) []byte {
return bytes.Replace([]byte(p), []byte{0}, []byte("/"), -1)
}

// splitKey splits a NUL-separated key into its parent portion and its final
// component. Trailing NUL separators are ignored when locating the final
// component, and the separator that introduces the final component is kept as
// the first byte of the second return value, so that appending the second
// value to the first reproduces k.
//
// The first return value is always a fresh copy and may be appended to without
// touching k. The second return value is a sub-slice of k. If k contains no
// separator before its final component (or is empty, or consists solely of
// separators), the parent is empty and the whole of k is returned as the final
// component. An empty or nil k yields two empty slices rather than panicking.
func splitKey(k []byte) ([]byte, []byte) {
	// Skip any trailing separators so that "a\x00b\x00" splits like "a\x00b".
	end := len(k)
	for end > 0 && k[end-1] == 0 {
		end--
	}

	// Walk backwards over the final component until we hit the separator that
	// precedes it, or the start of the key.
	i := end - 1
	for i > 0 && k[i] != 0 {
		i--
	}
	if i < 0 {
		// Empty key, or a key made up only of separators.
		i = 0
	}

	return append([]byte{}, k[:i]...), k[i:]
}
6 changes: 5 additions & 1 deletion cache/contenthash/path.go
Original file line number Diff line number Diff line change
Expand Up @@ -19,7 +19,7 @@ type onSymlinkFunc func(string, string) error
// symlink to the root directory.
// This is github.com/cyphar/filepath-securejoin.SecureJoinVFS's implementation
// with a callback on resolving the symlink.
func rootPath(root, unsafePath string, cb onSymlinkFunc) (string, error) {
func rootPath(root, unsafePath string, followTrailing bool, cb onSymlinkFunc) (string, error) {
if unsafePath == "" {
return root, nil
}
Expand Down Expand Up @@ -65,6 +65,10 @@ func rootPath(root, unsafePath string, cb onSymlinkFunc) (string, error) {
path = nextPath
continue
}
if !followTrailing && unsafePath == "" {
path = nextPath
break
}

// It's a symlink, so get its contents and expand it by prepending it
// to the yet-unparsed path.
Expand Down

0 comments on commit 6ef5a15

Please sign in to comment.