Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

assert: use type constraints (generics) #255

Open
wants to merge 3 commits into
base: main
Choose a base branch
from
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
24 changes: 13 additions & 11 deletions assert/assert.go
Original file line number Diff line number Diff line change
Expand Up @@ -100,13 +100,15 @@ import (

// BoolOrComparison can be a bool, cmp.Comparison, or error. See Assert for
// details about how this type is used.
type BoolOrComparison interface{}
type BoolOrComparison interface {
bool | func() (bool, string) | ~func() cmp.Result
Copy link
Member Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

This might need to include error. I believe it was previously possible to use assert.Assert(t, err).

}

// TestingT is the subset of testing.T used by the assert package.
type TestingT interface {
FailNow()
Fail()
Log(args ...interface{})
Log(args ...any)
}

type helperT interface {
Expand Down Expand Up @@ -138,7 +140,7 @@ type helperT interface {
// Assert uses t.FailNow to fail the test. Like t.FailNow, Assert must be called
// from the goroutine running the test function, not from other
// goroutines created during the test. Use Check from other goroutines.
func Assert(t TestingT, comparison BoolOrComparison, msgAndArgs ...interface{}) {
func Assert[C BoolOrComparison](t TestingT, comparison C, msgAndArgs ...any) {
if ht, ok := t.(helperT); ok {
ht.Helper()
}
Expand All @@ -152,7 +154,7 @@ func Assert(t TestingT, comparison BoolOrComparison, msgAndArgs ...interface{})
// is successful Check returns true. Check may be called from any goroutine.
//
// See Assert for details about the comparison arg and failure messages.
func Check(t TestingT, comparison BoolOrComparison, msgAndArgs ...interface{}) bool {
func Check[C BoolOrComparison](t TestingT, comparison C, msgAndArgs ...any) bool {
if ht, ok := t.(helperT); ok {
ht.Helper()
}
Expand All @@ -169,7 +171,7 @@ func Check(t TestingT, comparison BoolOrComparison, msgAndArgs ...interface{}) b
// NilError uses t.FailNow to fail the test. Like t.FailNow, NilError must be
// called from the goroutine running the test function, not from other
// goroutines created during the test. Use Check from other goroutines.
func NilError(t TestingT, err error, msgAndArgs ...interface{}) {
func NilError(t TestingT, err error, msgAndArgs ...any) {
if ht, ok := t.(helperT); ok {
ht.Helper()
}
Expand Down Expand Up @@ -197,7 +199,7 @@ func NilError(t TestingT, err error, msgAndArgs ...interface{}) {
// called from the goroutine running the test function, not from other
// goroutines created during the test. Use Check with cmp.Equal from other
// goroutines.
func Equal(t TestingT, x, y interface{}, msgAndArgs ...interface{}) {
func Equal[ANY any](t TestingT, x, y ANY, msgAndArgs ...any) {

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I wonder if this one could/should be the comparable constraint?

Copy link
Member Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Good question! I did attempt to use comparable here while working on this. It works for some types, but a few tests failed. Currently Equal works with pointers and interfaces like error. Changing to comparable caused the build to fail because error is not comparable. So I had to use any.

if ht, ok := t.(helperT); ok {
ht.Helper()
}
Expand All @@ -216,7 +218,7 @@ func Equal(t TestingT, x, y interface{}, msgAndArgs ...interface{}) {
// called from the goroutine running the test function, not from other
// goroutines created during the test. Use Check with cmp.DeepEqual from other
// goroutines.
func DeepEqual(t TestingT, x, y interface{}, opts ...gocmp.Option) {
func DeepEqual[ANY any](t TestingT, x, y ANY, opts ...gocmp.Option) {
if ht, ok := t.(helperT); ok {
ht.Helper()
}
Expand All @@ -235,7 +237,7 @@ func DeepEqual(t TestingT, x, y interface{}, opts ...gocmp.Option) {
// called from the goroutine running the test function, not from other
// goroutines created during the test. Use Check with cmp.Error from other
// goroutines.
func Error(t TestingT, err error, expected string, msgAndArgs ...interface{}) {
func Error(t TestingT, err error, expected string, msgAndArgs ...any) {
if ht, ok := t.(helperT); ok {
ht.Helper()
}
Expand All @@ -252,7 +254,7 @@ func Error(t TestingT, err error, expected string, msgAndArgs ...interface{}) {
// must be called from the goroutine running the test function, not from other
// goroutines created during the test. Use Check with cmp.ErrorContains from other
// goroutines.
func ErrorContains(t TestingT, err error, substring string, msgAndArgs ...interface{}) {
func ErrorContains(t TestingT, err error, substring string, msgAndArgs ...any) {
if ht, ok := t.(helperT); ok {
ht.Helper()
}
Expand Down Expand Up @@ -286,7 +288,7 @@ func ErrorContains(t TestingT, err error, substring string, msgAndArgs ...interf
// goroutines.
//
// Deprecated: Use ErrorIs
func ErrorType(t TestingT, err error, expected interface{}, msgAndArgs ...interface{}) {
func ErrorType(t TestingT, err error, expected any, msgAndArgs ...any) {
if ht, ok := t.(helperT); ok {
ht.Helper()
}
Expand All @@ -303,7 +305,7 @@ func ErrorType(t TestingT, err error, expected interface{}, msgAndArgs ...interf
// must be called from the goroutine running the test function, not from other
// goroutines created during the test. Use Check with cmp.ErrorIs from other
// goroutines.
func ErrorIs(t TestingT, err error, expected error, msgAndArgs ...interface{}) {
func ErrorIs(t TestingT, err error, expected error, msgAndArgs ...any) {
if ht, ok := t.(helperT); ok {
ht.Helper()
}
Expand Down
61 changes: 34 additions & 27 deletions assert/assert_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -125,28 +125,42 @@ func (c exampleComparison) Compare() (bool, string) {
return c.success, c.message
}

func TestAssertWithComparisonSuccess(t *testing.T) {
fakeT := &fakeTestingT{}

cmp := exampleComparison{success: true}
Assert(fakeT, cmp.Compare)
expectSuccess(t, fakeT)
}

func TestAssertWithComparisonFailure(t *testing.T) {
fakeT := &fakeTestingT{}

cmp := exampleComparison{message: "oops, not good"}
Assert(fakeT, cmp.Compare)
expectFailNowed(t, fakeT, "assertion failed: oops, not good")
}
func TestAssert_ArgumentTypes(t *testing.T) {
t.Run("compare function success", func(t *testing.T) {
fakeT := &fakeTestingT{}
cmp := exampleComparison{success: true}
Assert(fakeT, cmp.Compare)
expectSuccess(t, fakeT)
})
t.Run("compare function failure", func(t *testing.T) {
fakeT := &fakeTestingT{}
cmp := exampleComparison{message: "oops, not good"}
Assert(fakeT, cmp.Compare)
expectFailNowed(t, fakeT, "assertion failed: oops, not good")
})
t.Run("compare function failure with extra message", func(t *testing.T) {
fakeT := &fakeTestingT{}
cmp := exampleComparison{message: "oops, not good"}
Assert(fakeT, cmp.Compare, "extra stuff %v", true)
expectFailNowed(t, fakeT, "assertion failed: oops, not good: extra stuff true")
})

func TestAssertWithComparisonAndExtraMessage(t *testing.T) {
fakeT := &fakeTestingT{}
t.Run("bool", func(t *testing.T) {
fakeT := &fakeTestingT{}
Assert(fakeT, true)
expectSuccess(t, fakeT)
Assert(fakeT, false)
expectFailNowed(t, fakeT, "assertion failed: false is false")
})

cmp := exampleComparison{message: "oops, not good"}
Assert(fakeT, cmp.Compare, "extra stuff %v", true)
expectFailNowed(t, fakeT, "assertion failed: oops, not good: extra stuff true")
t.Run("result function", func(t *testing.T) {
fn := func() cmp.Result {
return cmp.ResultSuccess
}
fakeT := &fakeTestingT{}
Assert(fakeT, fn)
expectSuccess(t, fakeT)
})
}

type customError struct {
Expand Down Expand Up @@ -269,13 +283,6 @@ func TestEqualFailure(t *testing.T) {
expectFailNowed(t, fakeT, "assertion failed: 1 (actual int) != 3 (expected int)")
}

func TestEqualFailureTypes(t *testing.T) {
fakeT := &fakeTestingT{}

Equal(fakeT, 3, uint(3))
expectFailNowed(t, fakeT, `assertion failed: 3 (int) != 3 (uint)`)
}

func TestEqualFailureWithSelectorArgument(t *testing.T) {
fakeT := &fakeTestingT{}

Expand Down
30 changes: 16 additions & 14 deletions assert/cmp/compare.go
Original file line number Diff line number Diff line change
Expand Up @@ -24,7 +24,7 @@ type Comparison func() Result
// The comparison can be customized using comparison Options.
// Package http://pkg.go.dev/gotest.tools/v3/assert/opt provides some additional
// commonly used Options.
func DeepEqual(x, y interface{}, opts ...cmp.Option) Comparison {
func DeepEqual[ANY any](x, y ANY, opts ...cmp.Option) Comparison {
return func() (result Result) {
defer func() {
if panicmsg, handled := handleCmpPanic(recover()); handled {
Expand Down Expand Up @@ -63,7 +63,9 @@ func toResult(success bool, msg string) Result {

// RegexOrPattern may be either a *regexp.Regexp or a string that is a valid
// regexp pattern.
type RegexOrPattern interface{}
type RegexOrPattern interface {
~string | *regexp.Regexp
}

// Regexp succeeds if value v matches regular expression re.
//
Expand All @@ -72,15 +74,15 @@ type RegexOrPattern interface{}
// assert.Assert(t, cmp.Regexp("^[0-9a-f]{32}$", str))
// r := regexp.MustCompile("^[0-9a-f]{32}$")
// assert.Assert(t, cmp.Regexp(r, str))
func Regexp(re RegexOrPattern, v string) Comparison {
func Regexp[R RegexOrPattern](re R, v string) Comparison {
match := func(re *regexp.Regexp) Result {
return toResult(
re.MatchString(v),
fmt.Sprintf("value %q does not match regexp %q", v, re.String()))
}

return func() Result {
switch regex := re.(type) {
switch regex := any(re).(type) {
case *regexp.Regexp:
return match(regex)
case string:
Expand All @@ -96,13 +98,13 @@ func Regexp(re RegexOrPattern, v string) Comparison {
}

// Equal succeeds if x == y. See assert.Equal for full documentation.
func Equal(x, y interface{}) Comparison {
func Equal[ANY any](x, y ANY) Comparison {
return func() Result {
switch {
case x == y:
case any(x) == any(y):
return ResultSuccess
case isMultiLineStringCompare(x, y):
diff := format.UnifiedDiff(format.DiffConfig{A: x.(string), B: y.(string)})
diff := format.UnifiedDiff(format.DiffConfig{A: any(x).(string), B: any(y).(string)})
return multiLineDiffResult(diff, x, y)
}
return ResultFailureTemplate(`
Expand All @@ -117,7 +119,7 @@ func Equal(x, y interface{}) Comparison {
}
}

func isMultiLineStringCompare(x, y interface{}) bool {
func isMultiLineStringCompare(x, y any) bool {
strX, ok := x.(string)
if !ok {
return false
Expand All @@ -129,7 +131,7 @@ func isMultiLineStringCompare(x, y interface{}) bool {
return strings.Contains(strX, "\n") || strings.Contains(strY, "\n")
}

func multiLineDiffResult(diff string, x, y interface{}) Result {
func multiLineDiffResult(diff string, x, y any) Result {
return ResultFailureTemplate(`
--- {{ with callArg 0 }}{{ formatNode . }}{{else}}←{{end}}
+++ {{ with callArg 1 }}{{ formatNode . }}{{else}}→{{end}}
Expand All @@ -138,7 +140,7 @@ func multiLineDiffResult(diff string, x, y interface{}) Result {
}

// Len succeeds if the sequence has the expected length.
func Len(seq interface{}, expected int) Comparison {
func Len(seq any, expected int) Comparison {
return func() (result Result) {
defer func() {
if e := recover(); e != nil {
Expand All @@ -163,7 +165,7 @@ func Len(seq interface{}, expected int) Comparison {
// If collection is a Map, contains will succeed if item is a key in the map.
// If collection is a slice or array, item is compared to each item in the
// sequence using reflect.DeepEqual().
func Contains(collection interface{}, item interface{}) Comparison {
func Contains(collection any, item any) Comparison {
return func() Result {
colValue := reflect.ValueOf(collection)
if !colValue.IsValid() {
Expand Down Expand Up @@ -261,14 +263,14 @@ func formatErrorMessage(err error) string {
//
// Use NilError() for comparing errors. Use Len(obj, 0) for comparing slices,
// maps, and channels.
func Nil(obj interface{}) Comparison {
func Nil(obj any) Comparison {
msgFunc := func(value reflect.Value) string {
return fmt.Sprintf("%v (type %s) is not nil", reflect.Indirect(value), value.Type())
}
return isNil(obj, msgFunc)
}

func isNil(obj interface{}, msgFunc func(reflect.Value) string) Comparison {
func isNil(obj any, msgFunc func(reflect.Value) string) Comparison {
return func() Result {
if obj == nil {
return ResultSuccess
Expand Down Expand Up @@ -309,7 +311,7 @@ func isNil(obj interface{}, msgFunc func(reflect.Value) string) Comparison {
// Fails if err does not implement the reflect.Type.
//
// Deprecated: Use ErrorIs
func ErrorType(err error, expected interface{}) Comparison {
func ErrorType(err error, expected any) Comparison {
return func() Result {
switch expectedType := expected.(type) {
case func(error) bool:
Expand Down
26 changes: 11 additions & 15 deletions assert/cmp/compare_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -45,13 +45,15 @@ func TestDeepEqualWithUnexported(t *testing.T) {
}

func TestRegexp(t *testing.T) {
var testcases = []struct {
type testCase struct {
name string
regex interface{}
regex string
value string
match bool
expErr string
}{
}

var testcases = []testCase{
{
name: "pattern string match",
regex: "^[0-9]+$",
Expand All @@ -70,24 +72,12 @@ func TestRegexp(t *testing.T) {
value: "2123423456",
expErr: `value "2123423456" does not match regexp "^1"`,
},
{
name: "regexp match",
regex: regexp.MustCompile("^d[0-9a-f]{8}$"),
value: "d1632beef",
match: true,
},
{
name: "invalid regexp",
regex: "^1(",
value: "2",
expErr: "error parsing regexp: missing closing ): `^1(`",
},
{
name: "invalid type",
regex: struct{}{},
value: "some string",
expErr: "invalid type struct {} for regex pattern",
},
}

for _, tc := range testcases {
Expand All @@ -100,6 +90,12 @@ func TestRegexp(t *testing.T) {
}
})
}

t.Run("regexp match", func(t *testing.T) {
regex := regexp.MustCompile("^d[0-9a-f]{8}$")
res := Regexp(regex, "d1632beef")()
assertSuccess(t, res)
})
}

func TestLen(t *testing.T) {
Expand Down
2 changes: 1 addition & 1 deletion fs/example_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -29,7 +29,7 @@ func ExampleNewFile() {

content, err := os.ReadFile(file.Path())
assert.NilError(t, err)
assert.Equal(t, "content\n", content)
assert.Equal(t, "content\n", string(content))
}

// Create a directory and subdirectory with files
Expand Down
2 changes: 1 addition & 1 deletion go.mod
Original file line number Diff line number Diff line change
@@ -1,6 +1,6 @@
module gotest.tools/v3

go 1.17
go 1.18

require (
github.com/google/go-cmp v0.5.9
Expand Down
6 changes: 3 additions & 3 deletions internal/assert/assert.go
Original file line number Diff line number Diff line change
Expand Up @@ -27,8 +27,8 @@ const failureMessage = "assertion failed: "
func Eval(
t LogT,
argSelector argSelector,
comparison interface{},
msgAndArgs ...interface{},
comparison any,
msgAndArgs ...any,
) bool {
if ht, ok := t.(helperT); ok {
ht.Helper()
Expand Down Expand Up @@ -79,7 +79,7 @@ func runCompareFunc(
return true
}

func logFailureFromBool(t LogT, msgAndArgs ...interface{}) {
func logFailureFromBool(t LogT, msgAndArgs ...any) {
if ht, ok := t.(helperT); ok {
ht.Helper()
}
Expand Down
Loading