Skip to content

Commit

Permalink
fix: #2146
Browse files Browse the repository at this point in the history
  • Loading branch information
ludusrusso committed Oct 19, 2024
1 parent 2ec9004 commit 212a92d
Show file tree
Hide file tree
Showing 4 changed files with 73 additions and 2 deletions.
18 changes: 17 additions & 1 deletion pgtype/json.go
Original file line number Diff line number Diff line change
Expand Up @@ -143,10 +143,12 @@ func (c *JSONCodec) PlanScan(m *Map, oid uint32, format int16, target any) ScanP
case BytesScanner:
return scanPlanBinaryBytesToBytesScanner{}

}

// Cannot rely on sql.Scanner being handled later because scanPlanJSONToJSONUnmarshal will take precedence.
//
// https:/jackc/pgx/issues/1418
case sql.Scanner:
if isSQLScanner(target) {
return &scanPlanSQLScanner{formatCode: format}
}

Expand All @@ -155,6 +157,20 @@ func (c *JSONCodec) PlanScan(m *Map, oid uint32, format int16, target any) ScanP
}
}

// we need to check if the target is a pointer to a sql.Scanner (or any of the pointer ref tree implements a sql.Scanner).
//
// https:/jackc/pgx/issues/2146
func isSQLScanner(v any) bool {
val := reflect.ValueOf(v)
for val.Kind() == reflect.Ptr {
if _, ok := val.Interface().(sql.Scanner); ok {
return true
}
val = val.Elem()
}
return false
}

type scanPlanAnyToString struct{}

func (scanPlanAnyToString) Scan(src []byte, dst any) error {
Expand Down
27 changes: 27 additions & 0 deletions pgtype/json_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -63,6 +63,8 @@ func TestJSONCodec(t *testing.T) {

// Test driver.Valuer is used before json.Marshaler (https:/jackc/pgx/issues/1805)
{Issue1805(7), new(Issue1805), isExpectedEq(Issue1805(7))},
// Test driver.Scanner is used before json.Unmarshaler (https:/jackc/pgx/issues/2146)
{Issue2146(7), new(*Issue2146), isPtrExpectedEq(Issue2146(7))},
})

pgxtest.RunValueRoundTripTests(context.Background(), t, defaultConnTestRunner, pgxtest.KnownOIDQueryExecModes, "json", []pgxtest.ValueRoundTripTest{
Expand Down Expand Up @@ -109,6 +111,31 @@ func (i Issue1805) MarshalJSON() ([]byte, error) {
return nil, errors.New("MarshalJSON called")
}

type Issue2146 int

func (i *Issue2146) Scan(src any) error {
var source []byte
switch src.(type) {
case string:
source = []byte(src.(string))
case []byte:
source = src.([]byte)
default:
return errors.New("unknown source type")
}
var newI int
if err := json.Unmarshal(source, &newI); err != nil {
return err
}
*i = Issue2146(newI + 1)
return nil
}

func (i Issue2146) Value() (driver.Value, error) {
b, err := json.Marshal(int(i - 1))
return string(b), err
}

// https:/jackc/pgx/issues/1273#issuecomment-1221414648
func TestJSONCodecUnmarshalSQLNull(t *testing.T) {
defaultConnTestRunner.RunTest(context.Background(), t, func(ctx context.Context, t testing.TB, conn *pgx.Conn) {
Expand Down
22 changes: 21 additions & 1 deletion pgtype/pgtype.go
Original file line number Diff line number Diff line change
Expand Up @@ -396,7 +396,12 @@ type scanPlanSQLScanner struct {
}

func (plan *scanPlanSQLScanner) Scan(src []byte, dst any) error {
scanner := dst.(sql.Scanner)
scanner := getSQLScanner(dst)

if scanner == nil {
return fmt.Errorf("cannot scan into %T", dst)
}

if src == nil {
// This is necessary because interface value []byte:nil does not equal nil:nil for the binary format path and the
// text format path would be converted to empty string.
Expand All @@ -408,6 +413,21 @@ func (plan *scanPlanSQLScanner) Scan(src []byte, dst any) error {
}
}

// we don't know if the target is a sql.Scanner or a pointer on a sql.Scanner, so we need to check recursively
func getSQLScanner(target any) sql.Scanner {
val := reflect.ValueOf(target)
for val.Kind() == reflect.Ptr {
if _, ok := val.Interface().(sql.Scanner); ok {
if val.IsNil() {
val.Set(reflect.New(val.Type().Elem()))
}
return val.Interface().(sql.Scanner)
}
val = val.Elem()
}
return nil
}

type scanPlanString struct{}

func (scanPlanString) Scan(src []byte, dst any) error {
Expand Down
8 changes: 8 additions & 0 deletions pgtype/pgtype_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -9,6 +9,7 @@ import (
"fmt"
"net"
"os"
"reflect"
"regexp"
"strconv"
"testing"
Expand Down Expand Up @@ -631,3 +632,10 @@ func isExpectedEq(a any) func(any) bool {
return a == v
}
}

func isPtrExpectedEq(a any) func(any) bool {
return func(v any) bool {
val := reflect.ValueOf(v)
return a == val.Elem().Interface()
}
}

0 comments on commit 212a92d

Please sign in to comment.