Skip to content

Commit

Permalink
Use arshal flags instead of coder flags
Browse files Browse the repository at this point in the history
Instead of consulting the flags on the coder via:
	export.Encoder(enc).Flags
	export.Decoder(dec).Flags
consult flags on the call-provided arshal options via:
	mo.Flags
	uo.Flags

This avoids unnecessarily peaking at the internals flags
of Encoder or Decoder.

There should be no behavior changes as a result of this
since the arshal option flags should be identical.

There are 6 entry points to the "json" package:
*	Marshal
*	MarshalWrite
*	MarshaEncode
*	Unmarshal
*	UnmarshalRead
*	UnmarshalDecode

4 of them (Marshal, MarshalWrite, Unmarshal, UnmarshalRead)
obtain a coder from an internal pool,
where the arshal options plumbed down the call stack
is identical to the one inside the coder.
Therefore, there is no behavior difference for these 4 calls.

2 of them (MarshalEncode, UnmarshalDecode) take in a
user-provided coder, where there could theoretically
be a difference between the options struct and the coder options.
However, in both functions, we obtain a jsonopts.Struct locally
from an internal pool and then call jsonopts.Struct.CopyCoderOptions
where the coder options are copied from the user-provided coder
into the local jsonopts.Struct. Thus, the options struct that
we plumb down the stack is gauranteed to be a superset of
the options in the user-provided coder.
  • Loading branch information
dsnet committed Dec 28, 2024
1 parent 400546f commit 449871b
Show file tree
Hide file tree
Showing 6 changed files with 50 additions and 53 deletions.
10 changes: 4 additions & 6 deletions arshal.go
Original file line number Diff line number Diff line change
Expand Up @@ -225,9 +225,8 @@ func marshalEncode(out *jsontext.Encoder, in any, mo *jsonopts.Struct) (err erro
marshal, _ = mo.Marshalers.(*Marshalers).lookup(marshal, t)
}
if err := marshal(out, va, mo); err != nil {
xe := export.Encoder(out)
if !xe.Flags.Get(jsonflags.AllowDuplicateNames) {
xe.Tokens.InvalidateDisabledNamespaces()
if !mo.Flags.Get(jsonflags.AllowDuplicateNames) {
export.Encoder(out).Tokens.InvalidateDisabledNamespaces()
}
return err
}
Expand Down Expand Up @@ -461,9 +460,8 @@ func unmarshalDecode(in *jsontext.Decoder, out any, uo *jsonopts.Struct) (err er
unmarshal, _ = uo.Unmarshalers.(*Unmarshalers).lookup(unmarshal, t)
}
if err := unmarshal(in, va, uo); err != nil {
xd := export.Decoder(in)
if !xd.Flags.Get(jsonflags.AllowDuplicateNames) {
xd.Tokens.InvalidateDisabledNamespaces()
if !uo.Flags.Get(jsonflags.AllowDuplicateNames) {
export.Decoder(in).Tokens.InvalidateDisabledNamespaces()
}
return err
}
Expand Down
8 changes: 4 additions & 4 deletions arshal_any.go
Original file line number Diff line number Diff line change
Expand Up @@ -103,7 +103,7 @@ func marshalObjectAny(enc *jsontext.Encoder, obj map[string]any, mo *jsonopts.St
return enc.WriteToken(jsontext.Null)
}
// Optimize for marshaling an empty map without any preceding whitespace.
if !xe.Flags.Get(jsonflags.AnyWhitespace) && !xe.Tokens.Last.NeedObjectName() {
if !mo.Flags.Get(jsonflags.AnyWhitespace) && !xe.Tokens.Last.NeedObjectName() {
xe.Buf = append(xe.Tokens.MayAppendDelim(xe.Buf, '{'), "{}"...)
xe.Tokens.Last.Increment()
if xe.NeedFlush() {
Expand All @@ -118,7 +118,7 @@ func marshalObjectAny(enc *jsontext.Encoder, obj map[string]any, mo *jsonopts.St
}
// A Go map guarantees that each entry has a unique key
// The only possibility of duplicates is due to invalid UTF-8.
if !xe.Flags.Get(jsonflags.AllowInvalidUTF8) {
if !mo.Flags.Get(jsonflags.AllowInvalidUTF8) {
xe.Tokens.Last.DisableNamespace()
}
if !mo.Flags.Get(jsonflags.Deterministic) || len(obj) <= 1 {
Expand Down Expand Up @@ -168,7 +168,7 @@ func unmarshalObjectAny(dec *jsontext.Decoder, uo *jsonopts.Struct) (map[string]
obj := make(map[string]any)
// A Go map guarantees that each entry has a unique key
// The only possibility of duplicates is due to invalid UTF-8.
if !xd.Flags.Get(jsonflags.AllowInvalidUTF8) {
if !uo.Flags.Get(jsonflags.AllowInvalidUTF8) {
xd.Tokens.Last.DisableNamespace()
}
for dec.PeekKind() != '}' {
Expand Down Expand Up @@ -217,7 +217,7 @@ func marshalArrayAny(enc *jsontext.Encoder, arr []any, mo *jsonopts.Struct) erro
return enc.WriteToken(jsontext.Null)
}
// Optimize for marshaling an empty slice without any preceding whitespace.
if !xe.Flags.Get(jsonflags.AnyWhitespace) && !xe.Tokens.Last.NeedObjectName() {
if !mo.Flags.Get(jsonflags.AnyWhitespace) && !xe.Tokens.Last.NeedObjectName() {
xe.Buf = append(xe.Tokens.MayAppendDelim(xe.Buf, '['), "[]"...)
xe.Tokens.Last.Increment()
if xe.NeedFlush() {
Expand Down
44 changes: 22 additions & 22 deletions arshal_default.go
Original file line number Diff line number Diff line change
Expand Up @@ -127,7 +127,7 @@ func makeBoolArshaler(t reflect.Type) *arshaler {
}

// Optimize for marshaling without preceding whitespace.
if optimizeCommon && !xe.Flags.Get(jsonflags.AnyWhitespace) && !mo.Flags.Get(jsonflags.StringifyBoolsAndStrings) && !xe.Tokens.Last.NeedObjectName() {
if optimizeCommon && !mo.Flags.Get(jsonflags.AnyWhitespace|jsonflags.StringifyBoolsAndStrings) && !xe.Tokens.Last.NeedObjectName() {
xe.Buf = strconv.AppendBool(xe.Tokens.MayAppendDelim(xe.Buf, 't'), va.Bool())
xe.Tokens.Last.Increment()
if xe.NeedFlush() {
Expand Down Expand Up @@ -200,7 +200,7 @@ func makeStringArshaler(t reflect.Type) *arshaler {

// Optimize for marshaling without preceding whitespace or string escaping.
s := va.String()
if optimizeCommon && !xe.Flags.Get(jsonflags.AnyWhitespace) && !mo.Flags.Get(jsonflags.StringifyBoolsAndStrings) && !xe.Tokens.Last.NeedObjectName() && !jsonwire.NeedEscape(s) {
if optimizeCommon && !mo.Flags.Get(jsonflags.AnyWhitespace|jsonflags.StringifyBoolsAndStrings) && !xe.Tokens.Last.NeedObjectName() && !jsonwire.NeedEscape(s) {
b := xe.Buf
b = xe.Tokens.MayAppendDelim(b, '"')
b = append(b, '"')
Expand Down Expand Up @@ -446,7 +446,7 @@ func makeIntArshaler(t reflect.Type) *arshaler {
}

// Optimize for marshaling without preceding whitespace or string escaping.
if optimizeCommon && !xe.Flags.Get(jsonflags.AnyWhitespace) && !mo.Flags.Get(jsonflags.StringifyNumbers) && !xe.Tokens.Last.NeedObjectName() {
if optimizeCommon && !mo.Flags.Get(jsonflags.AnyWhitespace|jsonflags.StringifyNumbers) && !xe.Tokens.Last.NeedObjectName() {
xe.Buf = strconv.AppendInt(xe.Tokens.MayAppendDelim(xe.Buf, '0'), va.Int(), 10)
xe.Tokens.Last.Increment()
if xe.NeedFlush() {
Expand Down Expand Up @@ -533,7 +533,7 @@ func makeUintArshaler(t reflect.Type) *arshaler {
}

// Optimize for marshaling without preceding whitespace or string escaping.
if optimizeCommon && !xe.Flags.Get(jsonflags.AnyWhitespace) && !mo.Flags.Get(jsonflags.StringifyNumbers) && !xe.Tokens.Last.NeedObjectName() {
if optimizeCommon && !mo.Flags.Get(jsonflags.AnyWhitespace|jsonflags.StringifyNumbers) && !xe.Tokens.Last.NeedObjectName() {
xe.Buf = strconv.AppendUint(xe.Tokens.MayAppendDelim(xe.Buf, '0'), va.Uint(), 10)
xe.Tokens.Last.Increment()
if xe.NeedFlush() {
Expand Down Expand Up @@ -625,7 +625,7 @@ func makeFloatArshaler(t reflect.Type) *arshaler {
}

// Optimize for marshaling without preceding whitespace or string escaping.
if optimizeCommon && !xe.Flags.Get(jsonflags.AnyWhitespace) && !mo.Flags.Get(jsonflags.StringifyNumbers) && !xe.Tokens.Last.NeedObjectName() {
if optimizeCommon && !mo.Flags.Get(jsonflags.AnyWhitespace|jsonflags.StringifyNumbers) && !xe.Tokens.Last.NeedObjectName() {
xe.Buf = jsonwire.AppendFloat(xe.Tokens.MayAppendDelim(xe.Buf, '0'), fv, bits)
xe.Tokens.Last.Increment()
if xe.NeedFlush() {
Expand Down Expand Up @@ -755,7 +755,7 @@ func makeMapArshaler(t reflect.Type) *arshaler {
return enc.WriteToken(jsontext.Null)
}
// Optimize for marshaling an empty map without any preceding whitespace.
if optimizeCommon && !xe.Flags.Get(jsonflags.AnyWhitespace) && !xe.Tokens.Last.NeedObjectName() {
if optimizeCommon && !mo.Flags.Get(jsonflags.AnyWhitespace) && !xe.Tokens.Last.NeedObjectName() {
xe.Buf = append(xe.Tokens.MayAppendDelim(xe.Buf, '{'), "{}"...)
xe.Tokens.Last.Increment()
if xe.NeedFlush() {
Expand Down Expand Up @@ -785,7 +785,7 @@ func makeMapArshaler(t reflect.Type) *arshaler {
// A Go map guarantees that each entry has a unique key.
// As such, disable the expensive duplicate name check if we know
// that every Go key will serialize as a unique JSON string.
if !nonDefaultKey && mapKeyWithUniqueRepresentation(k.Kind(), xe.Flags.Get(jsonflags.AllowInvalidUTF8)) {
if !nonDefaultKey && mapKeyWithUniqueRepresentation(k.Kind(), mo.Flags.Get(jsonflags.AllowInvalidUTF8)) {
xe.Tokens.Last.DisableNamespace()
}

Expand Down Expand Up @@ -913,7 +913,7 @@ func makeMapArshaler(t reflect.Type) *arshaler {
// will be rejected as duplicates since they semantically refer
// to the same Go value. This is an unusual interaction
// between syntax and semantics, but is more correct.
if !nonDefaultKey && mapKeyWithUniqueRepresentation(k.Kind(), xd.Flags.Get(jsonflags.AllowInvalidUTF8)) {
if !nonDefaultKey && mapKeyWithUniqueRepresentation(k.Kind(), uo.Flags.Get(jsonflags.AllowInvalidUTF8)) {
xd.Tokens.Last.DisableNamespace()
}

Expand All @@ -922,7 +922,7 @@ func makeMapArshaler(t reflect.Type) *arshaler {
// since existing presence alone is insufficient to indicate
// whether the input had a duplicate name.
var seen reflect.Value
if !xd.Flags.Get(jsonflags.AllowDuplicateNames) && va.Len() > 0 {
if !uo.Flags.Get(jsonflags.AllowDuplicateNames) && va.Len() > 0 {
seen = reflect.MakeMap(reflect.MapOf(k.Type(), emptyStructType))
}

Expand All @@ -941,7 +941,7 @@ func makeMapArshaler(t reflect.Type) *arshaler {
}

if v2 := va.MapIndex(k.Value); v2.IsValid() {
if !xd.Flags.Get(jsonflags.AllowDuplicateNames) && (!seen.IsValid() || seen.MapIndex(k.Value).IsValid()) {
if !uo.Flags.Get(jsonflags.AllowDuplicateNames) && (!seen.IsValid() || seen.MapIndex(k.Value).IsValid()) {
// TODO: Unread the object name.
name := xd.PreviousTokenOrValue()
return newDuplicateNameError(dec.StackPointer(), nil, dec.InputOffset()-len64(name))
Expand Down Expand Up @@ -1078,20 +1078,20 @@ func makeStructArshaler(t reflect.Type) *arshaler {
b := xe.Buf
if xe.Tokens.Last.Length() > 0 {
b = append(b, ',')
if xe.Flags.Get(jsonflags.SpaceAfterComma) {
if mo.Flags.Get(jsonflags.SpaceAfterComma) {
b = append(b, ' ')
}
}
if xe.Flags.Get(jsonflags.Multiline) {
if mo.Flags.Get(jsonflags.Multiline) {
b = xe.AppendIndent(b, xe.Tokens.NeedIndent('"'))
}

// Append the token to the output and to the state machine.
n0 := len(b) // offset before calling AppendQuote
if !xe.Flags.Get(jsonflags.EscapeForHTML | jsonflags.EscapeForJS | jsonflags.EscapeInvalidUTF8) {
if !mo.Flags.Get(jsonflags.EscapeForHTML | jsonflags.EscapeForJS | jsonflags.EscapeInvalidUTF8) {
b = append(b, f.quotedName...)
} else {
b, _ = jsonwire.AppendQuote(b, f.name, &xe.Flags)
b, _ = jsonwire.AppendQuote(b, f.name, &mo.Flags)
}
xe.Buf = b
xe.Names.ReplaceLastQuotedOffset(n0)
Expand Down Expand Up @@ -1136,14 +1136,14 @@ func makeStructArshaler(t reflect.Type) *arshaler {
// Remember the previous written object member.
// The set of seen fields only needs to be updated to detect
// duplicate names with those from the inlined fallback.
if !xe.Flags.Get(jsonflags.AllowDuplicateNames) && fields.inlinedFallback != nil {
if !mo.Flags.Get(jsonflags.AllowDuplicateNames) && fields.inlinedFallback != nil {
seenIdxs.insert(uint(f.id))
}
prevIdx = f.id
}
if fields.inlinedFallback != nil && !(mo.Flags.Get(jsonflags.DiscardUnknownMembers) && fields.inlinedFallback.unknown) {
var insertUnquotedName func([]byte) bool
if !xe.Flags.Get(jsonflags.AllowDuplicateNames) {
if !mo.Flags.Get(jsonflags.AllowDuplicateNames) {
insertUnquotedName = func(name []byte) bool {
// Check that the name from inlined fallback does not match
// one of the previously marshaled names from known fields.
Expand Down Expand Up @@ -1215,7 +1215,7 @@ func makeStructArshaler(t reflect.Type) *arshaler {
if uo.Flags.Get(jsonflags.RejectUnknownMembers) && (fields.inlinedFallback == nil || fields.inlinedFallback.unknown) {
return newUnmarshalErrorAfter(dec, t, ErrUnknownName)
}
if !xd.Flags.Get(jsonflags.AllowDuplicateNames) && !xd.Namespaces.Last().InsertUnquoted(name) {
if !uo.Flags.Get(jsonflags.AllowDuplicateNames) && !xd.Namespaces.Last().InsertUnquoted(name) {
// TODO: Unread the object name.
return newDuplicateNameError(dec.StackPointer(), nil, dec.InputOffset()-len64(val))
}
Expand All @@ -1234,7 +1234,7 @@ func makeStructArshaler(t reflect.Type) *arshaler {
continue
}
}
if !xd.Flags.Get(jsonflags.AllowDuplicateNames) && !seenIdxs.insert(uint(f.id)) {
if !uo.Flags.Get(jsonflags.AllowDuplicateNames) && !seenIdxs.insert(uint(f.id)) {
// TODO: Unread the object name.
return newDuplicateNameError(dec.StackPointer(), nil, dec.InputOffset()-len64(val))
}
Expand Down Expand Up @@ -1383,7 +1383,7 @@ func makeSliceArshaler(t reflect.Type) *arshaler {
return enc.WriteToken(jsontext.Null)
}
// Optimize for marshaling an empty slice without any preceding whitespace.
if optimizeCommon && !xe.Flags.Get(jsonflags.AnyWhitespace) && !xe.Tokens.Last.NeedObjectName() {
if optimizeCommon && !mo.Flags.Get(jsonflags.AnyWhitespace) && !xe.Tokens.Last.NeedObjectName() {
xe.Buf = append(xe.Tokens.MayAppendDelim(xe.Buf, '['), "[]"...)
xe.Tokens.Last.Increment()
if xe.NeedFlush() {
Expand Down Expand Up @@ -1624,7 +1624,7 @@ func makePointerArshaler(t reflect.Type) *arshaler {
return err
}
if uo.Flags.Get(jsonflags.StringifyWithLegacySemantics) &&
(uo.Flags.Get(jsonflags.StringifyNumbers) || uo.Flags.Get(jsonflags.StringifyBoolsAndStrings)) {
(uo.Flags.Get(jsonflags.StringifyNumbers | jsonflags.StringifyBoolsAndStrings)) {
// A JSON null quoted within a JSON string should take effect
// within the pointer value, rather than the indirect value.
//
Expand Down Expand Up @@ -1664,7 +1664,7 @@ func makeInterfaceArshaler(t reflect.Type) *arshaler {
}
// Optimize for the any type if there are no special options.
if optimizeCommon &&
t == anyType && !mo.Flags.Get(jsonflags.StringifyNumbers) && !mo.Flags.Get(jsonflags.StringifyBoolsAndStrings) && mo.Format == "" &&
t == anyType && !mo.Flags.Get(jsonflags.StringifyNumbers|jsonflags.StringifyBoolsAndStrings) && mo.Format == "" &&
(mo.Marshalers == nil || !mo.Marshalers.(*Marshalers).fromAny) {
return marshalValueAny(enc, va.Elem().Interface(), mo)
}
Expand Down Expand Up @@ -1709,7 +1709,7 @@ func makeInterfaceArshaler(t reflect.Type) *arshaler {
// Duplicate name check must be enforced since unmarshalValueAny
// does not implement merge semantics.
if optimizeCommon &&
t == anyType && !xd.Flags.Get(jsonflags.AllowDuplicateNames) && uo.Format == "" &&
t == anyType && !uo.Flags.Get(jsonflags.AllowDuplicateNames) && uo.Format == "" &&
(uo.Unmarshalers == nil || !uo.Unmarshalers.(*Unmarshalers).fromAny) {
v, err := unmarshalValueAny(dec, uo)
// We must check for nil interface values up front.
Expand Down
16 changes: 8 additions & 8 deletions arshal_funcs.go
Original file line number Diff line number Diff line change
Expand Up @@ -178,7 +178,7 @@ func MarshalFuncV1[T any](fn func(T) ([]byte, error)) *Marshalers {
val, err := fn(va.castTo(t).Interface().(T))
if err != nil {
err = wrapSkipFunc(err, "marshal function of type func(T) ([]byte, error)")
if export.Encoder(enc).Flags.Get(jsonflags.ReportLegacyErrorValues) {
if mo.Flags.Get(jsonflags.ReportLegacyErrorValues) {
return internal.NewMarshalerError(va.Addr().Interface(), err, "MarshalFuncV1") // unlike unmarshal, always wrapped
}
err = newMarshalErrorBefore(enc, t, err)
Expand Down Expand Up @@ -216,9 +216,9 @@ func MarshalFuncV2[T any](fn func(*jsontext.Encoder, T, Options) error) *Marshal
fnc: func(enc *jsontext.Encoder, va addressableValue, mo *jsonopts.Struct) error {
xe := export.Encoder(enc)
prevDepth, prevLength := xe.Tokens.DepthLength()
xe.Flags.Set(jsonflags.WithinArshalCall | 1)
mo.Flags.Set(jsonflags.WithinArshalCall | 1)
err := fn(enc, va.castTo(t).Interface().(T), mo)
xe.Flags.Set(jsonflags.WithinArshalCall | 0)
mo.Flags.Set(jsonflags.WithinArshalCall | 0)
currDepth, currLength := xe.Tokens.DepthLength()
if err == nil && (prevDepth != currDepth || prevLength+1 != currLength) {
err = errNonSingularValue
Expand All @@ -230,7 +230,7 @@ func MarshalFuncV2[T any](fn func(*jsontext.Encoder, T, Options) error) *Marshal
}
err = errSkipMutation
}
if xe.Flags.Get(jsonflags.ReportLegacyErrorValues) {
if mo.Flags.Get(jsonflags.ReportLegacyErrorValues) {
return internal.NewMarshalerError(va.Addr().Interface(), err, "MarshalFuncV2") // unlike unmarshal, always wrapped
}
if !export.IsIOError(err) {
Expand Down Expand Up @@ -267,7 +267,7 @@ func UnmarshalFuncV1[T any](fn func([]byte, T) error) *Unmarshalers {
err = fn(val, va.castTo(t).Interface().(T))
if err != nil {
err = wrapSkipFunc(err, "unmarshal function of type func([]byte, T) error")
if export.Decoder(dec).Flags.Get(jsonflags.ReportLegacyErrorValues) {
if uo.Flags.Get(jsonflags.ReportLegacyErrorValues) {
return err // unlike marshal, never wrapped
}
err = newUnmarshalErrorAfter(dec, t, err)
Expand Down Expand Up @@ -298,9 +298,9 @@ func UnmarshalFuncV2[T any](fn func(*jsontext.Decoder, T, Options) error) *Unmar
fnc: func(dec *jsontext.Decoder, va addressableValue, uo *jsonopts.Struct) error {
xd := export.Decoder(dec)
prevDepth, prevLength := xd.Tokens.DepthLength()
xd.Flags.Set(jsonflags.WithinArshalCall | 1)
uo.Flags.Set(jsonflags.WithinArshalCall | 1)
err := fn(dec, va.castTo(t).Interface().(T), uo)
xd.Flags.Set(jsonflags.WithinArshalCall | 0)
uo.Flags.Set(jsonflags.WithinArshalCall | 0)
currDepth, currLength := xd.Tokens.DepthLength()
if err == nil && (prevDepth != currDepth || prevLength+1 != currLength) {
err = errNonSingularValue
Expand All @@ -312,7 +312,7 @@ func UnmarshalFuncV2[T any](fn func(*jsontext.Decoder, T, Options) error) *Unmar
}
err = errSkipMutation
}
if export.Decoder(dec).Flags.Get(jsonflags.ReportLegacyErrorValues) {
if uo.Flags.Get(jsonflags.ReportLegacyErrorValues) {
return err // unlike marshal, never wrapped
}
if !isSyntacticError(err) && !export.IsIOError(err) {
Expand Down
3 changes: 1 addition & 2 deletions arshal_inlined.go
Original file line number Diff line number Diff line change
Expand Up @@ -111,8 +111,7 @@ func marshalInlinedFallbackAll(enc *jsontext.Encoder, va addressableValue, mo *j
mk := newAddressableValue(m.Type().Key())
mv := newAddressableValue(m.Type().Elem())
marshalKey := func(mk addressableValue) error {
xe := export.Encoder(enc)
b, err := jsonwire.AppendQuote(enc.UnusedBuffer(), mk.String(), &xe.Flags)
b, err := jsonwire.AppendQuote(enc.UnusedBuffer(), mk.String(), &mo.Flags)
if err != nil {
return newMarshalErrorBefore(enc, m.Type().Key(), err)
}
Expand Down
Loading

0 comments on commit 449871b

Please sign in to comment.