Skip to content

Commit

Permalink
add support for custom deepcopy logic
Browse files Browse the repository at this point in the history
  • Loading branch information
xrstf committed Dec 2, 2023
1 parent e2d0344 commit 7bea6fb
Show file tree
Hide file tree
Showing 2 changed files with 73 additions and 3 deletions.
8 changes: 8 additions & 0 deletions pkg/deepcopy/deepcopy.go
Original file line number Diff line number Diff line change
Expand Up @@ -30,6 +30,10 @@ func MustClone[T any](val T) T {
return cloned
}

type Copier interface {
DeepCopy() (any, error)
}

func clone(val any) (any, error) {
switch asserted := val.(type) {
// Go native types
Expand Down Expand Up @@ -66,6 +70,10 @@ func clone(val any) (any, error) {
case ast.String:
return asserted, nil

// custom logic
case Copier:
return asserted.DeepCopy()

default:
return nil, fmt.Errorf("cannot deep-copy %T", val)
}
Expand Down
68 changes: 65 additions & 3 deletions pkg/deepcopy/deepcopy_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -8,6 +8,40 @@ import (
"github.com/google/go-cmp/cmp"
)

type customCopier struct {
Value any
}

var _ Copier = customCopier{}

func (c customCopier) DeepCopy() (any, error) {
copied, err := Clone(c.Value)
if err != nil {
return nil, err
}

return customCopier{
Value: copied,
}, nil
}

type customPtrCopier struct {
Value any
}

var _ Copier = &customPtrCopier{}

func (c *customPtrCopier) DeepCopy() (any, error) {
copied, err := Clone(c.Value)
if err != nil {
return nil, err
}

return &customPtrCopier{
Value: copied,
}, nil
}

func TestCloneScalars(t *testing.T) {
testcases := []struct {
input any
Expand Down Expand Up @@ -75,7 +109,7 @@ func TestCloneScalars(t *testing.T) {
}

if !cmp.Equal(cloned, testcase.expected) {
t.Fatalf("Unpected result:\n\n%s\n", renderDiff(testcase.expected, cloned))
t.Fatalf("Unexpected result:\n\n%s\n", renderDiff(testcase.expected, cloned))
}

if &cloned == &testcase.input {
Expand Down Expand Up @@ -120,7 +154,7 @@ func TestCloneMapDeep(t *testing.T) {
}

if !cmp.Equal(cloned, input) {
t.Fatalf("Unpected result:\n\n%s\n", renderDiff(input, cloned))
t.Fatalf("Unexpected result:\n\n%s\n", renderDiff(input, cloned))
}

helloList := input["hello"].([]any)
Expand Down Expand Up @@ -164,7 +198,7 @@ func TestCloneSliceDeep(t *testing.T) {
}

if !cmp.Equal(cloned, input) {
t.Fatalf("Unpected result:\n\n%s\n", renderDiff(input, cloned))
t.Fatalf("Unexpected result:\n\n%s\n", renderDiff(input, cloned))
}

helloObj := input[3].(map[string]any)
Expand All @@ -175,6 +209,34 @@ func TestCloneSliceDeep(t *testing.T) {
}
}

func TestCustomCopier(t *testing.T) {
data := map[string]any{"foo": "bar"}
cc := customCopier{Value: data}

cloned, err := Clone(cc)
if err != nil {
t.Fatalf("Failed to clone: %v", err)
}

if !cmp.Equal(cc, cloned) {
t.Fatalf("Unexpected result:\n\n%s\n", renderDiff(cc, cloned))
}
}

func TestCustomPtrCopier(t *testing.T) {
data := map[string]any{"foo": "bar"}
cc := &customPtrCopier{Value: data}

cloned, err := Clone(cc)
if err != nil {
t.Fatalf("Failed to clone: %v", err)
}

if !cmp.Equal(cc, cloned) {
t.Fatalf("Unexpected result:\n\n%s\n", renderDiff(cc, cloned))
}
}

func renderDiff(expected any, actual any) string {
var builder strings.Builder

Expand Down

0 comments on commit 7bea6fb

Please sign in to comment.