Skip to content

Commit

Permalink
rework JSONCodec.PlanScan
Browse files Browse the repository at this point in the history
  • Loading branch information
felix-roehrich committed Jan 22, 2025
1 parent 0bc29e3 commit a5353af
Show file tree
Hide file tree
Showing 3 changed files with 77 additions and 74 deletions.
95 changes: 45 additions & 50 deletions pgtype/json.go
Original file line number Diff line number Diff line change
Expand Up @@ -71,6 +71,27 @@ func (c *JSONCodec) PlanEncode(m *Map, oid uint32, format int16, value any) Enco
}
}

// JSON needs its on scan plan for pointers to handle 'null'::json(b).
// Consider making pointerPointerScanPlan more flexible in the future.
type jsonPointerScanPlan struct {
next ScanPlan
}

func (p jsonPointerScanPlan) Scan(src []byte, dst any) error {
el := reflect.ValueOf(dst).Elem()
if src == nil || string(src) == "null" {
el.SetZero()
return nil
}

el.Set(reflect.New(el.Type().Elem()))
if p.next != nil {
return p.next.Scan(src, el.Interface())
}

return nil
}

type encodePlanJSONCodecEitherFormatString struct{}

func (encodePlanJSONCodecEitherFormatString) Encode(value any, buf []byte) (newBuf []byte, err error) {
Expand Down Expand Up @@ -117,64 +138,38 @@ func (e *encodePlanJSONCodecEitherFormatMarshal) Encode(value any, buf []byte) (
return buf, nil
}

func (c *JSONCodec) PlanScan(m *Map, oid uint32, format int16, target any) ScanPlan {
func (c *JSONCodec) PlanScan(m *Map, oid uint32, formatCode int16, target any) ScanPlan {
return c.planScan(m, oid, formatCode, target, 0)
}

// JSON cannot fallback to pointerPointerScanPlan because of 'null'::json(b),
// so we need to duplicate the logic here.
func (c *JSONCodec) planScan(m *Map, oid uint32, formatCode int16, target any, depth int) ScanPlan {
if depth > 8 {
return &scanPlanFail{m: m, oid: oid, formatCode: formatCode}
}

switch target.(type) {
case *string:
return scanPlanAnyToString{}

case **string:
// This is to fix **string scanning. It seems wrong to special case **string, but it's not clear what a better
// solution would be.
//
// https://github.com/jackc/pgx/issues/1470 -- **string
// https://github.com/jackc/pgx/issues/1691 -- ** anything else

if wrapperPlan, nextDst, ok := TryPointerPointerScanPlan(target); ok {
if nextPlan := m.planScan(oid, format, nextDst, 0); nextPlan != nil {
if _, failed := nextPlan.(*scanPlanFail); !failed {
wrapperPlan.SetNext(nextPlan)
return wrapperPlan
}
}
}

return &scanPlanAnyToString{}
case *[]byte:
return scanPlanJSONToByteSlice{}
return &scanPlanJSONToByteSlice{}
case BytesScanner:
return scanPlanBinaryBytesToBytesScanner{}

}

// Cannot rely on sql.Scanner being handled later because scanPlanJSONToJSONUnmarshal will take precedence.
//
// https://github.com/jackc/pgx/issues/1418
if isSQLScanner(target) {
return &scanPlanSQLScanner{formatCode: format}
return &scanPlanBinaryBytesToBytesScanner{}
case sql.Scanner:
return &scanPlanSQLScanner{formatCode: formatCode}
}

return &scanPlanJSONToJSONUnmarshal{
unmarshal: c.Unmarshal,
rv := reflect.ValueOf(target)
if rv.Kind() == reflect.Pointer && rv.Elem().Kind() == reflect.Pointer {
var plan jsonPointerScanPlan
plan.next = c.planScan(m, oid, formatCode, rv.Elem().Interface(), depth+1)
return plan
} else {
return &scanPlanJSONToJSONUnmarshal{unmarshal: c.Unmarshal}
}
}

// we need to check if the target is a pointer to a sql.Scanner (or any of the pointer ref tree implements a sql.Scanner).
//
// https://github.com/jackc/pgx/issues/2146
func isSQLScanner(v any) bool {
if _, is := v.(sql.Scanner); is {
return true
}

val := reflect.ValueOf(v)
for val.Kind() == reflect.Ptr {
if _, ok := val.Interface().(sql.Scanner); ok {
return true
}
val = val.Elem()
}
return false
}

type scanPlanAnyToString struct{}

func (scanPlanAnyToString) Scan(src []byte, dst any) error {
Expand Down Expand Up @@ -202,7 +197,7 @@ type scanPlanJSONToJSONUnmarshal struct {
}

func (s *scanPlanJSONToJSONUnmarshal) Scan(src []byte, dst any) error {
if src == nil {
if src == nil || string(src) == "null" {
dstValue := reflect.ValueOf(dst)
if dstValue.Kind() == reflect.Ptr {
el := dstValue.Elem()
Expand Down
31 changes: 31 additions & 0 deletions pgtype/json_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -326,3 +326,34 @@ func TestJSONCodecScanToNonPointerValues(t *testing.T) {
require.Equal(t, 42, m)
})
}

func TestJSONCodecScanNull(t *testing.T) {
defaultConnTestRunner.RunTest(context.Background(), t, func(ctx context.Context, t testing.TB, conn *pgx.Conn) {
var dest struct{}
err := conn.QueryRow(ctx, "select null::jsonb").Scan(&dest)
require.Error(t, err)
require.Contains(t, err.Error(), "cannot scan NULL into *struct {}")

err = conn.QueryRow(ctx, "select 'null'::jsonb").Scan(&dest)
require.Error(t, err)
require.Contains(t, err.Error(), "cannot scan NULL into *struct {}")

var destPointer *struct{}
err = conn.QueryRow(ctx, "select null::jsonb").Scan(&destPointer)
require.NoError(t, err)
require.Nil(t, destPointer)

err = conn.QueryRow(ctx, "select 'null'::jsonb").Scan(&destPointer)
require.NoError(t, err)
require.Nil(t, destPointer)
})
}

func TestJSONCodecScanNullToPointerToSQLScanner(t *testing.T) {
defaultConnTestRunner.RunTest(context.Background(), t, func(ctx context.Context, t testing.TB, conn *pgx.Conn) {
var dest *Issue2146
err := conn.QueryRow(ctx, "select null::jsonb").Scan(&dest)
require.NoError(t, err)
require.Nil(t, dest)
})
}
25 changes: 1 addition & 24 deletions pgtype/pgtype.go
Original file line number Diff line number Diff line change
Expand Up @@ -396,11 +396,7 @@ type scanPlanSQLScanner struct {
}

func (plan *scanPlanSQLScanner) Scan(src []byte, dst any) error {
scanner := getSQLScanner(dst)

if scanner == nil {
return fmt.Errorf("cannot scan into %T", dst)
}
scanner := dst.(sql.Scanner)

if src == nil {
// This is necessary because interface value []byte:nil does not equal nil:nil for the binary format path and the
Expand All @@ -413,25 +409,6 @@ func (plan *scanPlanSQLScanner) Scan(src []byte, dst any) error {
}
}

// we don't know if the target is a sql.Scanner or a pointer on a sql.Scanner, so we need to check recursively
func getSQLScanner(target any) sql.Scanner {
if sc, is := target.(sql.Scanner); is {
return sc
}

val := reflect.ValueOf(target)
for val.Kind() == reflect.Ptr {
if _, ok := val.Interface().(sql.Scanner); ok {
if val.IsNil() {
val.Set(reflect.New(val.Type().Elem()))
}
return val.Interface().(sql.Scanner)
}
val = val.Elem()
}
return nil
}

type scanPlanString struct{}

func (scanPlanString) Scan(src []byte, dst any) error {
Expand Down

0 comments on commit a5353af

Please sign in to comment.