diff --git a/pgtype/json.go b/pgtype/json.go index 76cec51b8..6f7ebb51f 100644 --- a/pgtype/json.go +++ b/pgtype/json.go @@ -71,6 +71,27 @@ func (c *JSONCodec) PlanEncode(m *Map, oid uint32, format int16, value any) Enco } } +// JSON needs its on scan plan for pointers to handle 'null'::json(b). +// Consider making pointerPointerScanPlan more flexible in the future. +type jsonPointerScanPlan struct { + next ScanPlan +} + +func (p jsonPointerScanPlan) Scan(src []byte, dst any) error { + el := reflect.ValueOf(dst).Elem() + if src == nil || string(src) == "null" { + el.SetZero() + return nil + } + + el.Set(reflect.New(el.Type().Elem())) + if p.next != nil { + return p.next.Scan(src, el.Interface()) + } + + return nil +} + type encodePlanJSONCodecEitherFormatString struct{} func (encodePlanJSONCodecEitherFormatString) Encode(value any, buf []byte) (newBuf []byte, err error) { @@ -117,64 +138,38 @@ func (e *encodePlanJSONCodecEitherFormatMarshal) Encode(value any, buf []byte) ( return buf, nil } -func (c *JSONCodec) PlanScan(m *Map, oid uint32, format int16, target any) ScanPlan { +func (c *JSONCodec) PlanScan(m *Map, oid uint32, formatCode int16, target any) ScanPlan { + return c.planScan(m, oid, formatCode, target, 0) +} + +// JSON cannot fallback to pointerPointerScanPlan because of 'null'::json(b), +// so we need to duplicate the logic here. +func (c *JSONCodec) planScan(m *Map, oid uint32, formatCode int16, target any, depth int) ScanPlan { + if depth > 8 { + return &scanPlanFail{m: m, oid: oid, formatCode: formatCode} + } + switch target.(type) { case *string: - return scanPlanAnyToString{} - - case **string: - // This is to fix **string scanning. It seems wrong to special case **string, but it's not clear what a better - // solution would be. - // - // https://github.com/jackc/pgx/issues/1470 -- **string - // https://github.com/jackc/pgx/issues/1691 -- ** anything else - - if wrapperPlan, nextDst, ok := TryPointerPointerScanPlan(target); ok { - if nextPlan := m.planScan(oid, format, nextDst, 0); nextPlan != nil { - if _, failed := nextPlan.(*scanPlanFail); !failed { - wrapperPlan.SetNext(nextPlan) - return wrapperPlan - } - } - } - + return &scanPlanAnyToString{} case *[]byte: - return scanPlanJSONToByteSlice{} + return &scanPlanJSONToByteSlice{} case BytesScanner: - return scanPlanBinaryBytesToBytesScanner{} - - } - - // Cannot rely on sql.Scanner being handled later because scanPlanJSONToJSONUnmarshal will take precedence. - // - // https://github.com/jackc/pgx/issues/1418 - if isSQLScanner(target) { - return &scanPlanSQLScanner{formatCode: format} + return &scanPlanBinaryBytesToBytesScanner{} + case sql.Scanner: + return &scanPlanSQLScanner{formatCode: formatCode} } - return &scanPlanJSONToJSONUnmarshal{ - unmarshal: c.Unmarshal, + rv := reflect.ValueOf(target) + if rv.Kind() == reflect.Pointer && rv.Elem().Kind() == reflect.Pointer { + var plan jsonPointerScanPlan + plan.next = c.planScan(m, oid, formatCode, rv.Elem().Interface(), depth+1) + return plan + } else { + return &scanPlanJSONToJSONUnmarshal{unmarshal: c.Unmarshal} } } -// we need to check if the target is a pointer to a sql.Scanner (or any of the pointer ref tree implements a sql.Scanner). -// -// https://github.com/jackc/pgx/issues/2146 -func isSQLScanner(v any) bool { - if _, is := v.(sql.Scanner); is { - return true - } - - val := reflect.ValueOf(v) - for val.Kind() == reflect.Ptr { - if _, ok := val.Interface().(sql.Scanner); ok { - return true - } - val = val.Elem() - } - return false -} - type scanPlanAnyToString struct{} func (scanPlanAnyToString) Scan(src []byte, dst any) error { @@ -202,7 +197,7 @@ type scanPlanJSONToJSONUnmarshal struct { } func (s *scanPlanJSONToJSONUnmarshal) Scan(src []byte, dst any) error { - if src == nil { + if src == nil || string(src) == "null" { dstValue := reflect.ValueOf(dst) if dstValue.Kind() == reflect.Ptr { el := dstValue.Elem() diff --git a/pgtype/json_test.go b/pgtype/json_test.go index 78bdc3a9d..ebe144aae 100644 --- a/pgtype/json_test.go +++ b/pgtype/json_test.go @@ -326,3 +326,34 @@ func TestJSONCodecScanToNonPointerValues(t *testing.T) { require.Equal(t, 42, m) }) } + +func TestJSONCodecScanNull(t *testing.T) { + defaultConnTestRunner.RunTest(context.Background(), t, func(ctx context.Context, t testing.TB, conn *pgx.Conn) { + var dest struct{} + err := conn.QueryRow(ctx, "select null::jsonb").Scan(&dest) + require.Error(t, err) + require.Contains(t, err.Error(), "cannot scan NULL into *struct {}") + + err = conn.QueryRow(ctx, "select 'null'::jsonb").Scan(&dest) + require.Error(t, err) + require.Contains(t, err.Error(), "cannot scan NULL into *struct {}") + + var destPointer *struct{} + err = conn.QueryRow(ctx, "select null::jsonb").Scan(&destPointer) + require.NoError(t, err) + require.Nil(t, destPointer) + + err = conn.QueryRow(ctx, "select 'null'::jsonb").Scan(&destPointer) + require.NoError(t, err) + require.Nil(t, destPointer) + }) +} + +func TestJSONCodecScanNullToPointerToSQLScanner(t *testing.T) { + defaultConnTestRunner.RunTest(context.Background(), t, func(ctx context.Context, t testing.TB, conn *pgx.Conn) { + var dest *Issue2146 + err := conn.QueryRow(ctx, "select null::jsonb").Scan(&dest) + require.NoError(t, err) + require.Nil(t, dest) + }) +} diff --git a/pgtype/pgtype.go b/pgtype/pgtype.go index 20645d694..a1083161c 100644 --- a/pgtype/pgtype.go +++ b/pgtype/pgtype.go @@ -396,11 +396,7 @@ type scanPlanSQLScanner struct { } func (plan *scanPlanSQLScanner) Scan(src []byte, dst any) error { - scanner := getSQLScanner(dst) - - if scanner == nil { - return fmt.Errorf("cannot scan into %T", dst) - } + scanner := dst.(sql.Scanner) if src == nil { // This is necessary because interface value []byte:nil does not equal nil:nil for the binary format path and the @@ -413,25 +409,6 @@ func (plan *scanPlanSQLScanner) Scan(src []byte, dst any) error { } } -// we don't know if the target is a sql.Scanner or a pointer on a sql.Scanner, so we need to check recursively -func getSQLScanner(target any) sql.Scanner { - if sc, is := target.(sql.Scanner); is { - return sc - } - - val := reflect.ValueOf(target) - for val.Kind() == reflect.Ptr { - if _, ok := val.Interface().(sql.Scanner); ok { - if val.IsNil() { - val.Set(reflect.New(val.Type().Elem())) - } - return val.Interface().(sql.Scanner) - } - val = val.Elem() - } - return nil -} - type scanPlanString struct{} func (scanPlanString) Scan(src []byte, dst any) error {