diff --git a/cmd/toml-test-decoder/main.go b/cmd/toml-test-decoder/main.go index 344047f1..a6486fc2 100644 --- a/cmd/toml-test-decoder/main.go +++ b/cmd/toml-test-decoder/main.go @@ -31,13 +31,14 @@ func main() { } var decoded interface{} - if _, err := toml.DecodeReader(os.Stdin, &decoded); err != nil { + meta, err := toml.DecodeReader(os.Stdin, &decoded) + if err != nil { log.Fatalf("Error decoding TOML: %s", err) } j := json.NewEncoder(os.Stdout) j.SetIndent("", " ") - if err := j.Encode(tag.Add("", decoded)); err != nil { + if err := j.Encode(tag.Add(meta, "", decoded)); err != nil { log.Fatalf("Error encoding JSON: %s", err) } } diff --git a/decode.go b/decode.go index 39f7d8bd..6ff2af24 100644 --- a/decode.go +++ b/decode.go @@ -128,11 +128,12 @@ func (dec *Decoder) Decode(v interface{}) (MetaData, error) { return MetaData{}, err } md := MetaData{ - mapping: p.mapping, - types: p.types, - keys: p.ordered, - decoded: make(map[string]bool, len(p.ordered)), - context: nil, + mapping: p.mapping, + types: p.types, + keys: p.ordered, + comments: p.comments, + decoded: make(map[string]bool, len(p.ordered)), + context: nil, } return md, md.unify(p.mapping, indirect(rv)) } @@ -462,6 +463,7 @@ func (md *MetaData) unifyText(data interface{}, v encoding.TextUnmarshaler) erro var s string switch sdata := data.(type) { case Marshaler: + fmt.Println("unifyText (Marshaler)", data, "in to", v) text, err := sdata.MarshalTOML() if err != nil { return err @@ -473,6 +475,7 @@ func (md *MetaData) unifyText(data interface{}, v encoding.TextUnmarshaler) erro return err } s = string(text) + // fmt.Println("unifyText (TextMarshaler)", data, "in to", v, "=", s) case fmt.Stringer: s = sdata.String() case string: diff --git a/encode.go b/encode.go index 0804c653..109883c9 100644 --- a/encode.go +++ b/encode.go @@ -2,6 +2,7 @@ package toml import ( "bufio" + "bytes" "encoding" "errors" "fmt" @@ -98,8 +99,16 @@ type Encoder struct { // String to use for a single indentation level; default is two spaces. Indent string + // TODO(v2): Ident should be a function so we can do: + // + // NewEncoder(os.Stdout).SetIndent("prefix", "indent").MetaData(meta).Encode() + // + // Prefix is also useful to have. + w *bufio.Writer hasWritten bool // written any output to w yet? + wroteNL int // How many newlines do we have in a row? + meta *MetaData } // NewEncoder create a new Encoder. @@ -110,6 +119,17 @@ func NewEncoder(w io.Writer) *Encoder { } } +// MetaData sets the metadata for this encoder. +// +// This can be used to control the formatting; see the documentation of MetaData +// for more details. +// +// XXX: Rename to SetMeta() +func (enc *Encoder) MetaData(m MetaData) *Encoder { + enc.meta = &m + return enc +} + // Encode writes a TOML representation of the Go value to the Encoder's writer. // // An error is returned if the value given cannot be encoded to a valid TOML @@ -136,7 +156,39 @@ func (enc *Encoder) safeEncode(key Key, rv reflect.Value) (err error) { return nil } +// Newline rules: + +// nocomment = "value" +// no_bl = 1 +// +// # With comment: has blank line before the comment, and one after the key. +// with_c = 2 +// +// asd = 1 +// qwe = 2 # After comment: no extra newline +// zxc = 3 +// +// [tbl] # Always has newline before it.key +// +// # With comment +// [tbl2] +// key1 = 123 +// +// func (enc *Encoder) encode(key Key, rv reflect.Value) { + extraNL := false + if enc.meta != nil && enc.meta.comments != nil { + comments := enc.meta.comments[key.String()] + for _, c := range comments { + if c.where == commentDoc { + extraNL = true + enc.w.WriteString("# ") + enc.w.WriteString(strings.ReplaceAll(c.text, "\n", "\n# ")) + enc.newline(1) + } + } + } + // Special case: time needs to be in ISO8601 format. // // Special case: if we can marshal the type to text, then we used that. This @@ -145,112 +197,182 @@ func (enc *Encoder) encode(key Key, rv reflect.Value) { switch t := rv.Interface().(type) { case time.Time, encoding.TextMarshaler, Marshaler: enc.writeKeyValue(key, rv, false) - return // TODO: #76 would make this superfluous after implemented. + // TODO: remove in v2 case Primitive: enc.encode(key, reflect.ValueOf(t.undecoded)) - return - } + default: - k := rv.Kind() - switch k { - case reflect.Int, reflect.Int8, reflect.Int16, reflect.Int32, - reflect.Int64, - reflect.Uint, reflect.Uint8, reflect.Uint16, reflect.Uint32, - reflect.Uint64, - reflect.Float32, reflect.Float64, reflect.String, reflect.Bool: - enc.writeKeyValue(key, rv, false) - case reflect.Array, reflect.Slice: - if typeEqual(tomlArrayHash, tomlTypeOfGo(rv)) { - enc.eArrayOfTables(key, rv) - } else { + k := rv.Kind() + switch k { + case reflect.Int, reflect.Int8, reflect.Int16, reflect.Int32, + reflect.Int64, + reflect.Uint, reflect.Uint8, reflect.Uint16, reflect.Uint32, + reflect.Uint64, + reflect.Float32, reflect.Float64, reflect.String, reflect.Bool: enc.writeKeyValue(key, rv, false) + case reflect.Array, reflect.Slice: + if typeEqual(ArrayTable{}, tomlTypeOfGo(rv)) { + enc.eArrayOfTables(key, rv) + } else { + enc.writeKeyValue(key, rv, false) + } + case reflect.Interface: + if rv.IsNil() { + return + } + enc.encode(key, rv.Elem()) + case reflect.Map: + if rv.IsNil() { + return + } + enc.eTable(key, rv) + case reflect.Ptr: + if rv.IsNil() { + return + } + enc.encode(key, rv.Elem()) + case reflect.Struct: + enc.eTable(key, rv) + default: + encPanic(fmt.Errorf("unsupported type for key '%s': %s", key, k)) } - case reflect.Interface: - if rv.IsNil() { - return - } - enc.encode(key, rv.Elem()) - case reflect.Map: - if rv.IsNil() { - return - } - enc.eTable(key, rv) - case reflect.Ptr: - if rv.IsNil() { - return + } + + // Write comments after the key. + if enc.meta != nil && enc.meta.comments != nil { + comments := enc.meta.comments[key.String()] + for _, c := range comments { + if c.where == commentComment { + enc.w.WriteString(" # ") + enc.w.WriteString(strings.ReplaceAll(c.text, "\n", "\n# ")) + enc.newline(1) + } } - enc.encode(key, rv.Elem()) - case reflect.Struct: - enc.eTable(key, rv) - default: - encPanic(fmt.Errorf("unsupported type for key '%s': %s", key, k)) } + + enc.newline(1) + if extraNL { + enc.newline(1) + } +} + +func (enc *Encoder) writeInt(typ tomlType, v uint64) { + var ( + iTyp = asInt(typ) + base = int(iTyp.Base) + ) + switch iTyp.Base { + case 0: + base = 10 + case 2: + enc.wf("0b") + case 8: + enc.wf("0o") + case 16: + enc.wf("0x") + } + + n := strconv.FormatUint(uint64(v), base) + if base != 10 && iTyp.Width > 0 && len(n) < int(iTyp.Width) { + enc.wf(strings.Repeat("0", int(iTyp.Width)-len(n))) + } + enc.wf(n) } // eElement encodes any value that can be an array element. -func (enc *Encoder) eElement(rv reflect.Value) { +func (enc *Encoder) eElement(rv reflect.Value, typ tomlType) { + //fmt.Printf("ENC %T -> %s -> %[1]v\n", rv.Interface(), typ) + switch v := rv.Interface().(type) { case time.Time: // Using TextMarshaler adds extra quotes, which we don't want. - format := time.RFC3339Nano - switch v.Location() { - case internal.LocalDatetime: + format := "" + switch asDatetime(typ).Format { + case 0: // Undefined, check for special TZ. + format = time.RFC3339Nano + switch v.Location() { + case internal.LocalDatetime: + format = "2006-01-02T15:04:05.999999999" + case internal.LocalDate: + format = "2006-01-02" + case internal.LocalTime: + format = "15:04:05.999999999" + } + + case DatetimeFormatFull: + format = time.RFC3339Nano + case DatetimeFormatLocal: format = "2006-01-02T15:04:05.999999999" - case internal.LocalDate: + case DatetimeFormatDate: format = "2006-01-02" - case internal.LocalTime: + case DatetimeFormatTime: format = "15:04:05.999999999" - } - switch v.Location() { default: - enc.wf(v.Format(format)) - case internal.LocalDatetime, internal.LocalDate, internal.LocalTime: - enc.wf(v.In(time.UTC).Format(format)) + encPanic(fmt.Errorf("Invalid datetime format: %v", asDatetime(typ).Format)) + } + + //fmt.Printf("ENC %T -> %s -> %[1]v\n", rv.Interface(), typ) + //fmt.Println("XXX", asDatetime(typ).Format) + if format != time.RFC3339Nano { + //v = v.In(time.UTC) } + + //switch v.Location() { + //default: + enc.wf(v.Format(format)) + //case internal.LocalDatetime, internal.LocalDate, internal.LocalTime: + // enc.wf(v.In(time.UTC).Format(format)) + //} return case Marshaler: s, err := v.MarshalTOML() if err != nil { encPanic(err) } - enc.writeQuoted(string(s)) + enc.writeQuoted(string(s), asString(typ)) return case encoding.TextMarshaler: s, err := v.MarshalText() if err != nil { encPanic(err) } - enc.writeQuoted(string(s)) + enc.writeQuoted(string(s), asString(typ)) return } switch rv.Kind() { - case reflect.String: - enc.writeQuoted(rv.String()) case reflect.Bool: enc.wf(strconv.FormatBool(rv.Bool())) + case reflect.String: + enc.writeQuoted(rv.String(), asString(typ)) case reflect.Int, reflect.Int8, reflect.Int16, reflect.Int32, reflect.Int64: - enc.wf(strconv.FormatInt(rv.Int(), 10)) - case reflect.Uint, reflect.Uint8, reflect.Uint16, reflect.Uint32, reflect.Uint64: - enc.wf(strconv.FormatUint(rv.Uint(), 10)) - case reflect.Float32: - f := rv.Float() - if math.IsNaN(f) { - enc.wf("nan") - } else if math.IsInf(f, 0) { - enc.wf("%cinf", map[bool]byte{true: '-', false: '+'}[math.Signbit(f)]) - } else { - enc.wf(floatAddDecimal(strconv.FormatFloat(f, 'f', -1, 32))) + v := rv.Int() + if v < 0 { // Make sure sign is before "0x". + enc.wf("-") + v = -v } - case reflect.Float64: + enc.writeInt(typ, uint64(v)) + case reflect.Uint, reflect.Uint8, reflect.Uint16, reflect.Uint32, reflect.Uint64: + enc.writeInt(typ, rv.Uint()) + + case reflect.Float32, reflect.Float64: f := rv.Float() if math.IsNaN(f) { enc.wf("nan") } else if math.IsInf(f, 0) { enc.wf("%cinf", map[bool]byte{true: '-', false: '+'}[math.Signbit(f)]) } else { - enc.wf(floatAddDecimal(strconv.FormatFloat(f, 'f', -1, 64))) + n := 64 + if rv.Kind() == reflect.Float32 { + n = 32 + } + if asFloat(typ).Exponent { + enc.wf(strconv.FormatFloat(f, 'e', -1, n)) + } else { + enc.wf(floatAddDecimal(strconv.FormatFloat(f, 'f', -1, n))) + } } + case reflect.Array, reflect.Slice: enc.eArrayOrSliceElement(rv) case reflect.Struct: @@ -258,7 +380,7 @@ func (enc *Encoder) eElement(rv reflect.Value) { case reflect.Map: enc.eMap(nil, rv, true) case reflect.Interface: - enc.eElement(rv.Elem()) + enc.eElement(rv.Elem(), typ) default: encPanic(fmt.Errorf("unexpected primitive type: %T", rv.Interface())) } @@ -273,8 +395,21 @@ func floatAddDecimal(fstr string) string { return fstr } -func (enc *Encoder) writeQuoted(s string) { - enc.wf("\"%s\"", dblQuotedReplacer.Replace(s)) +func (enc *Encoder) writeQuoted(s string, typ String) { + if typ.Literal { + if typ.Multiline { + enc.wf("'''%s'''\n", s) + } else { + enc.wf(`'%s'`, s) + } + } else { + if typ.Multiline { + enc.wf(`"""%s"""`+"\n", + strings.ReplaceAll(dblQuotedReplacer.Replace(s), "\\n", "\n")) + } else { + enc.wf(`"%s"`, dblQuotedReplacer.Replace(s)) + } + } } func (enc *Encoder) eArrayOrSliceElement(rv reflect.Value) { @@ -282,7 +417,7 @@ func (enc *Encoder) eArrayOrSliceElement(rv reflect.Value) { enc.wf("[") for i := 0; i < length; i++ { elem := rv.Index(i) - enc.eElement(elem) + enc.eElement(elem, nil) // XXX: add type if i != length-1 { enc.wf(", ") } @@ -299,22 +434,21 @@ func (enc *Encoder) eArrayOfTables(key Key, rv reflect.Value) { if isNil(trv) { continue } - enc.newline() + + enc.newline(2) enc.wf("%s[[%s]]", enc.indentStr(key), key.maybeQuotedAll()) - enc.newline() + enc.newline(1) enc.eMapOrStruct(key, trv, false) } } func (enc *Encoder) eTable(key Key, rv reflect.Value) { - if len(key) == 1 { - // Output an extra newline between top-level tables. - // (The newline isn't written if nothing else has been written though.) - enc.newline() + if len(key) == 1 { // Output an extra newline between top-level tables. + enc.newline(2) } if len(key) > 0 { enc.wf("%s[%s]", enc.indentStr(key), key.maybeQuotedAll()) - enc.newline() + enc.newline(1) } enc.eMapOrStruct(key, rv, false) } @@ -501,46 +635,46 @@ func tomlTypeOfGo(rv reflect.Value) tomlType { } switch rv.Kind() { case reflect.Bool: - return tomlBool + return Bool{} case reflect.Int, reflect.Int8, reflect.Int16, reflect.Int32, reflect.Int64, reflect.Uint, reflect.Uint8, reflect.Uint16, reflect.Uint32, reflect.Uint64: - return tomlInteger + return Int{} case reflect.Float32, reflect.Float64: - return tomlFloat + return Float{} case reflect.Array, reflect.Slice: - if typeEqual(tomlHash, tomlArrayType(rv)) { - return tomlArrayHash + if typeEqual(Table{}, tomlArrayType(rv)) { + return ArrayTable{} } - return tomlArray + return Array{} case reflect.Ptr, reflect.Interface: return tomlTypeOfGo(rv.Elem()) case reflect.String: - return tomlString + return String{} case reflect.Map: - return tomlHash + return Table{} case reflect.Struct: switch rv.Interface().(type) { case time.Time: - return tomlDatetime + return Datetime{} case encoding.TextMarshaler: - return tomlString + return String{} default: // Someone used a pointer receiver: we can make it work for pointer // values. if rv.CanAddr() { _, ok := rv.Addr().Interface().(encoding.TextMarshaler) if ok { - return tomlString + return String{} } } - return tomlHash + return Table{} } default: _, ok := rv.Interface().(encoding.TextMarshaler) if ok { - return tomlString + return String{} } encPanic(errors.New("unsupported type: " + rv.Kind().String())) panic("") // Need *some* return value @@ -620,9 +754,23 @@ func isEmpty(rv reflect.Value) bool { return false } -func (enc *Encoder) newline() { - if enc.hasWritten { - enc.wf("\n") +// newline ensures there are n newlines here. +func (enc *Encoder) newline(n int) { + // Don't write any newlines at the top of the file. + if !enc.hasWritten { + return + } + + w := n - enc.wroteNL + if w <= 0 { + return + } + + enc.wroteNL += w + //enc.wf(strings.Repeat("\n", w)) + _, err := enc.w.Write(bytes.Repeat([]byte("\n"), w)) + if err != nil { + encPanic(err) } } @@ -637,16 +785,22 @@ func (enc *Encoder) newline() { // │ ┌───┐ ┌─────┐│ // v v v v vv // key = {k = v, k2 = v2} -// func (enc *Encoder) writeKeyValue(key Key, val reflect.Value, inline bool) { if len(key) == 0 { encPanic(errNoKey) } enc.wf("%s%s = ", enc.indentStr(key), key.maybeQuoted(len(key)-1)) - enc.eElement(val) - if !inline { - enc.newline() + + var typ tomlType + if enc.meta != nil { + if t, ok := enc.meta.types[key.String()]; ok { + typ = t + } } + enc.eElement(val, typ) + // if !inline { + // enc.newline() + // } } func (enc *Encoder) wf(format string, v ...interface{}) { @@ -654,6 +808,7 @@ func (enc *Encoder) wf(format string, v ...interface{}) { if err != nil { encPanic(err) } + enc.wroteNL = 0 enc.hasWritten = true } diff --git a/encode_test.go b/encode_test.go index 65aa10a0..74a4aa61 100644 --- a/encode_test.go +++ b/encode_test.go @@ -11,6 +11,137 @@ import ( "time" ) +// Copy from _example/example.go +type ( + example struct { + Title string `toml:"title"` + Integers []int `toml:"integers"` + //Times []fmtTime `toml:"times"` + Times []time.Time `toml:"times"` + Duration []duration `toml:"duration"` + //Distros []distro `toml:"distros"` + //Servers map[string]server `toml:"servers"` + //Characters map[string][]struct { + // Name string `toml:"name"` + // Rank string `toml:"rank"` + //} `toml:"characters"` + } + + server struct { + IP string `toml:"ip"` + Hostname string `toml:"hostname"` + Enabled bool `toml:"enabled"` + } + + distro struct { + Name string `toml:"name"` + Packages string `toml:"packages"` + } + + duration struct{ time.Duration } + //fmtTime struct{ time.Time } +) + +func (d *duration) UnmarshalText(text []byte) (err error) { + d.Duration, err = time.ParseDuration(string(text)) + return err +} + +func (d duration) MarshalText() ([]byte, error) { + return []byte(d.Duration.String()), nil +} + +//func (t fmtTime) String() string { +// f := "2006-01-02 15:04:05.999999999" +// if t.Time.Hour() == 0 { +// f = "2006-01-02" +// } +// if t.Time.Year() == 0 { +// f = "15:04:05.999999999" +// } +// if t.Time.Location() == time.UTC { +// f += " UTC" +// } else { +// f += " -0700" +// } +// return t.Time.Format(`"` + f + `"`) +//} + +func TestXXX(t *testing.T) { + { + var m struct { + // TODO: this doesn't work if the `toml:"d"` struct tag isn't + // present. + // + // In WriteKeyValue() it uses key.String(), which is "D2" rather + // than "d2" as it should be. + // + // This is set wrong: it should be set to D2. Actually, both are + // "correct", since the TOML has "d2", but it will write as "D2". + // Maybe add helper or something? + // + // This is already a problem in the existing implementation, but I + // guess not too many people use IsDefined() etc. + D time.Time `toml:"d"` + //D2 time.Time + } + //meta, _ := Decode("d = 2020-01-02\nd2 = 2020-01-02", &m) + meta, _ := Decode("d = 2020-01-02", &m) + //fmt.Printf("d → %s %s %#v\n", m.D, m.D.Location(), meta.TypeInfo("d")) + //fmt.Printf("d2 → %s %s %#v\n", m.D2, m.D2.Location(), meta.TypeInfo("d2")) + //fmt.Printf("D2 → %s %s %#v\n\n", m.D2, m.D2.Location(), meta.TypeInfo("D2")) + + // Wrong because unifyText() doesn't do the right thing? Used to work? + // Hmm... + NewEncoder(os.Stdout).Encode(m) + // fmt.Println() + // NewEncoder(os.Stdout).MetaData(meta).Encode(m) + _ = meta + } + + return + var decoded example + meta, err := DecodeFile("_example/example.toml", &decoded) + if err != nil { + t.Fatal(err) + } + + buf := new(bytes.Buffer) + enc := NewEncoder(buf) + enc.MetaData(meta) + err = enc.Encode(decoded) + if err != nil { + t.Fatal(err) + } + + fmt.Println("types") + for k, v := range meta.types { + fmt.Printf(" %-24s %v\n", k, v) + } + fmt.Println() + + fmt.Println("keys") + for _, k := range meta.keys { + fmt.Printf(" %s\n", k) + } + + fmt.Println("mapping") + for k, v := range meta.mapping { + fmt.Printf(" %-24s %v\n", k, v) + } + fmt.Println() + + //fmt.Println("comments") + //for k, v := range meta.comments { + // fmt.Printf(" %-24s %v\n", k, v) + //} + //fmt.Println() + + fmt.Println(strings.Repeat("-", 60)) + fmt.Print(buf) + fmt.Println(strings.Repeat("-", 60)) +} + func TestEncodeRoundTrip(t *testing.T) { type Config struct { Age int @@ -468,6 +599,46 @@ func TestEncode32bit(t *testing.T) { nil) } +func TestEncodeHints(t *testing.T) { + return + foo := struct { + ML string `toml:"ml"` + Lit string `toml:"lit"` + Cmt string `toml:"cmt"` + N int `toml:"n"` + N2 int `toml:"n2"` + F1 float64 `toml:"f1"` + D1 time.Time `toml:"d1"` + }{} + + meta, err := NewDecoder(strings.NewReader(` + ml = """ MULTI """ + lit = 'asd' + # A test comment. + cmt = ''' asd ''' + n = 0x42 + n2 = +11231 + f1 = 2e-2 + d1 = 15:19:11 + `)).Decode(&foo) + if err != nil { + t.Fatal(err) + } + + meta.Doc("ml", "Hello").Comment("ml", "inline") + meta.SetType("n", Int{Width: 4, Base: 16}) + + buf := new(bytes.Buffer) + enc := NewEncoder(buf) + enc.MetaData(meta) + err = enc.Encode(foo) + if err != nil { + t.Fatal(err) + } + + fmt.Println(buf.String()) +} + func encodeExpected(t *testing.T, label string, val interface{}, want string, wantErr error) { t.Helper() diff --git a/internal/tag/add.go b/internal/tag/add.go index 88f69503..76787f4a 100644 --- a/internal/tag/add.go +++ b/internal/tag/add.go @@ -5,11 +5,12 @@ import ( "math" "time" + "github.com/BurntSushi/toml" "github.com/BurntSushi/toml/internal" ) // Add JSON tags to a data structure as expected by toml-test. -func Add(key string, tomlData interface{}) interface{} { +func Add(meta toml.MetaData, key string, tomlData interface{}) interface{} { // Switch on the data type. switch orig := tomlData.(type) { default: @@ -20,7 +21,7 @@ func Add(key string, tomlData interface{}) interface{} { case map[string]interface{}: typed := make(map[string]interface{}, len(orig)) for k, v := range orig { - typed[k] = Add(k, v) + typed[k] = Add(meta, k, v) } return typed @@ -29,26 +30,41 @@ func Add(key string, tomlData interface{}) interface{} { case []map[string]interface{}: typed := make([]map[string]interface{}, len(orig)) for i, v := range orig { - typed[i] = Add("", v).(map[string]interface{}) + typed[i] = Add(meta, "", v).(map[string]interface{}) } return typed case []interface{}: typed := make([]interface{}, len(orig)) for i, v := range orig { - typed[i] = Add("", v) + typed[i] = Add(meta, "", v) } return typed // Datetime: tag as datetime. case time.Time: - switch orig.Location() { + dtFmt := toml.DatetimeFormatFull + if dt, ok := meta.TypeInfo(key).(toml.Datetime); ok { + dtFmt = dt.Format + } + switch dtFmt { default: + panic(fmt.Sprintf("unexpected datetime format: %#v for %q", dtFmt, key)) + case toml.DatetimeFormatFull: + switch orig.Location() { + case internal.LocalDatetime: + return tag("datetime-local", orig.Format("2006-01-02T15:04:05.999999999")) + case internal.LocalDate: + return tag("date-local", orig.Format("2006-01-02")) + case internal.LocalTime: + return tag("time-local", orig.Format("15:04:05.999999999")) + } + return tag("datetime", orig.Format("2006-01-02T15:04:05.999999999Z07:00")) - case internal.LocalDatetime: + case toml.DatetimeFormatLocal: return tag("datetime-local", orig.Format("2006-01-02T15:04:05.999999999")) - case internal.LocalDate: + case toml.DatetimeFormatDate: return tag("date-local", orig.Format("2006-01-02")) - case internal.LocalTime: + case toml.DatetimeFormatTime: return tag("time-local", orig.Format("15:04:05.999999999")) } diff --git a/internal/tag/rm.go b/internal/tag/rm.go index a8903fd1..67620522 100644 --- a/internal/tag/rm.go +++ b/internal/tag/rm.go @@ -78,6 +78,21 @@ func untag(typed map[string]interface{}) (interface{}, error) { return nil, fmt.Errorf("untag: %w", err) } return f, nil + + // XXX: this loses the "meta" information that's required. + // this is a bit annoying: the datetime is the only type that requires + // access to the metadata to be semantically correct. All the other values + // have different notations, but are semantically identical (0x10 == 16). + // + // Maybe add back the special timezones we used before, so the time.Time + // is "self-contained"? + // + // When decoding -> set both meta and type + // + // When encoding -> use meta if set, falling back to the TZ. + // + // time.Now() will encode as "full", unless meta is set. + // Decoding a time and setting meta to "full" will encode as such. case "datetime": return parseTime(v, "2006-01-02T15:04:05.999999999Z07:00", nil) case "datetime-local": @@ -86,6 +101,7 @@ func untag(typed map[string]interface{}) (interface{}, error) { return parseTime(v, "2006-01-02", internal.LocalDate) case "time-local": return parseTime(v, "15:04:05.999999999", internal.LocalTime) + case "bool": switch v { case "true": diff --git a/lex.go b/lex.go index b23302fb..63c868c0 100644 --- a/lex.go +++ b/lex.go @@ -15,7 +15,11 @@ const ( itemError itemType = iota itemNIL // used in the parser to indicate no type itemEOF - itemText + itemCommentStart + itemComment + itemKeyStart + itemKeyEnd + itemKey itemString itemRawString itemMultilineString @@ -24,15 +28,12 @@ const ( itemInteger itemFloat itemDatetime - itemArray // the start of an array + itemArrayStart itemArrayEnd itemTableStart itemTableEnd itemArrayTableStart itemArrayTableEnd - itemKeyStart - itemKeyEnd - itemCommentStart itemInlineTableStart itemInlineTableEnd ) @@ -401,7 +402,7 @@ func lexBareName(lx *lexer) stateFn { return lexBareName } lx.backup() - lx.emit(itemText) + lx.emit(itemKey) return lx.pop() } @@ -500,7 +501,7 @@ func lexValue(lx *lexer) stateFn { switch r { case '[': lx.ignore() - lx.emit(itemArray) + lx.emit(itemArrayStart) return lexArrayValue case '{': lx.ignore() @@ -1121,8 +1122,8 @@ func lexBool(lx *lexer) stateFn { return lx.errorf("expected value but found %q instead", s) } -// lexCommentStart begins the lexing of a comment. It will emit -// itemCommentStart and consume no characters, passing control to lexComment. +// lexCommentStart begins the lexing of a comment. It will emit itemCommentStart +// and consume no characters, passing control to lexComment. func lexCommentStart(lx *lexer) stateFn { lx.ignore() lx.emit(itemCommentStart) @@ -1136,7 +1137,7 @@ func lexComment(lx *lexer) stateFn { switch r := lx.next(); { case isNL(r) || r == eof: lx.backup() - lx.emit(itemText) + lx.emit(itemComment) return lx.pop() case isControl(r): return lx.errorControlChar(r) @@ -1170,8 +1171,10 @@ func (itype itemType) String() string { return "NIL" case itemEOF: return "EOF" - case itemText: - return "Text" + case itemKey: + return "BareKey" + case itemComment: + return "Comment" case itemString, itemRawString, itemMultilineString, itemRawMultilineString: return "String" case itemBool: @@ -1190,8 +1193,8 @@ func (itype itemType) String() string { return "KeyStart" case itemKeyEnd: return "KeyEnd" - case itemArray: - return "Array" + case itemArrayStart: + return "ArrayStart" case itemArrayEnd: return "ArrayEnd" case itemCommentStart: diff --git a/meta.go b/meta.go index fecee5b5..fe9c4af9 100644 --- a/meta.go +++ b/meta.go @@ -1,16 +1,78 @@ package toml -import "strings" +import ( + "fmt" + "strings" +) -// MetaData allows access to meta information about TOML data that may not be -// inferable via reflection. In particular, whether a key has been defined and -// the TOML type of a key. +// MetaData allows access to meta information about TOML. +// +// It allows determining whether a key has been defined, the TOML type of a +// key, and how it's formatted. It also records comments in the TOML file. type MetaData struct { - mapping map[string]interface{} - types map[string]tomlType - keys []Key - decoded map[string]bool - context Key // Used only during decoding. + mapping map[string]interface{} + types map[string]tomlType // TOML types. + keys []Key // List of defined keys. + decoded map[string]bool // Decoded keys. + context Key // Used only during decoding. + comments map[string][]comment // Record comments. +} + +const ( + _ = iota + commentDoc // Above the key. + commentComment // "Inline" after the key. +) + +type comment struct { + where int + text string +} + +func NewMetaData() MetaData { + return MetaData{} +} + +type ( + Doc string + Comment string +) + +func (enc *MetaData) Key(key string, args ...interface{}) *MetaData { + for _, a := range args { + switch aa := a.(type) { + default: + panic(fmt.Sprintf("toml.MetaData.Key: unsupported type: %T", a)) + case tomlType: + enc.SetType(key, aa) + case Doc: + enc.Doc(key, string(aa)) + case Comment: + enc.Comment(key, string(aa)) + } + } + return enc +} + +func (enc *MetaData) SetType(key string, t tomlType) *MetaData { + enc.types[key] = t + return enc +} + +func (enc *MetaData) Doc(key string, doc string) *MetaData { + if enc.comments == nil { + enc.comments = make(map[string][]comment) + } + enc.comments[key] = append(enc.comments[key], comment{where: commentDoc, text: doc}) + return enc +} + +func (enc *MetaData) Comment(key string, doc string) *MetaData { + if enc.comments == nil { + enc.comments = make(map[string][]comment) + } + enc.comments[key] = append(enc.comments[key], comment{where: commentComment, text: doc}) + return enc } // IsDefined reports if the key exists in the TOML data. @@ -45,13 +107,21 @@ func (md *MetaData) IsDefined(key ...string) bool { // Type will return the empty string if given an empty key or a key that does // not exist. Keys are case sensitive. func (md *MetaData) Type(key ...string) string { - fullkey := strings.Join(key, ".") - if typ, ok := md.types[fullkey]; ok { - return typ.typeString() + if t, ok := md.types[Key(key).String()]; ok { + return t.String() } return "" } +func (md *MetaData) TypeInfo(key ...string) tomlType { + // TODO(v2): Type() would be a better name for this, but that's already + // used. We can change this to: + // + // meta.TypeInfo() → meta.Type() + // meta.IsDefined() → meta.Type() == nil + return md.types[Key(key).String()] +} + // Keys returns a slice of every key in the TOML data, including key groups. // // Each key is itself a slice, where the first element is the top of the diff --git a/parse.go b/parse.go index e0d62f83..d8204039 100644 --- a/parse.go +++ b/parse.go @@ -11,13 +11,16 @@ import ( ) type parser struct { - mapping map[string]interface{} - types map[string]tomlType - lx *lexer - - ordered []Key // List of keys in the order that they appear in the TOML data. - context Key // Full key for the current hash in scope. - currentKey string // Base key name for everything except hashes. + mapping map[string]interface{} + types map[string]tomlType + comments map[string][]comment + lx *lexer + + ordered []Key // List of keys in the order that they appear in the TOML data. + context Key // Full key for the current hash in scope. + currentKey string // Base key name for everything except hashes. + prevKey string + comment []string pos Position // Position implicits map[string]bool // Record implied keys (e.g. 'key.group.names'). } @@ -59,6 +62,7 @@ func parse(data string) (p *parser, err error) { p = &parser{ mapping: make(map[string]interface{}), types: make(map[string]tomlType), + comments: make(map[string][]comment), lx: lex(data), ordered: make([]Key, 0), implicits: make(map[string]bool), @@ -135,7 +139,21 @@ func (p *parser) assertEqual(expected, got itemType) { func (p *parser) topLevel(item item) { switch item.typ { case itemCommentStart: // # .. - p.expect(itemText) + text := p.expect(itemComment) + + // XXX: we need to associate this comment with a key: + // + // - If it's inline, associate with previous key. + // - If it's above a key, associate with next key. + // + // Memorize the comment if it's above a key (set p.doc), and associate + // that when we read the key. + // + // For inline keys we can use p.context + p.prevKey. + + p.comment = append(p.comment, text.val) + //k := append(p.context, p.prevKey).String() + //p.comments[k] = append(p.comments[k], comment{where: commentDoc, text: text.val}) case itemTableStart: // [ .. ] name := p.nextPos() @@ -146,7 +164,7 @@ func (p *parser) topLevel(item item) { p.assertEqual(itemTableEnd, name.typ) p.addContext(key, false) - p.setType("", tomlHash) + p.setType("", Table{}) p.ordered = append(p.ordered, key) case itemArrayTableStart: // [[ .. ]] name := p.nextPos() @@ -158,7 +176,7 @@ func (p *parser) topLevel(item item) { p.assertEqual(itemArrayTableEnd, name.typ) p.addContext(key, true) - p.setType("", tomlArrayHash) + p.setType("", ArrayTable{}) p.ordered = append(p.ordered, key) case itemKeyStart: // key = .. outerContext := p.context @@ -180,6 +198,13 @@ func (p *parser) topLevel(item item) { p.addImplicitContext(append(p.context, context[i:i+1]...)) } + if len(p.comment) > 0 { + for _, c := range p.comment { + p.comments[p.currentKey] = append(p.comments[p.currentKey], comment{where: commentDoc, text: c}) + } + p.comment = nil + } + /// Set value. val, typ := p.value(p.next(), false) p.set(p.currentKey, val, typ) @@ -187,7 +212,7 @@ func (p *parser) topLevel(item item) { /// Remove the context we added (preserving any context from [tbl] lines). p.context = outerContext - p.currentKey = "" + p.currentKey, p.prevKey = "", p.currentKey default: p.bug("Unexpected type at top level: %s", item.typ) } @@ -196,7 +221,7 @@ func (p *parser) topLevel(item item) { // Gets a string for a key (or part of a key in a table name). func (p *parser) keyString(it item) string { switch it.typ { - case itemText: + case itemKey, itemComment: return it.val case itemString, itemMultilineString, itemRawString, itemRawMultilineString: @@ -204,8 +229,8 @@ func (p *parser) keyString(it item) string { return s.(string) default: p.bug("Unexpected key type: %s", it.typ) + panic("unreachable") } - panic("unreachable") } var datetimeRepl = strings.NewReplacer( @@ -218,13 +243,13 @@ var datetimeRepl = strings.NewReplacer( func (p *parser) value(it item, parentIsArray bool) (interface{}, tomlType) { switch it.typ { case itemString: - return p.replaceEscapes(it, it.val), p.typeOfPrimitive(it) + return p.replaceEscapes(it, it.val), String{} case itemMultilineString: - return p.replaceEscapes(it, stripFirstNewline(stripEscapedNewlines(it.val))), p.typeOfPrimitive(it) + return p.replaceEscapes(it, stripFirstNewline(stripEscapedNewlines(it.val))), String{Multiline: true} case itemRawString: - return it.val, p.typeOfPrimitive(it) + return it.val, String{Literal: true} case itemRawMultilineString: - return stripFirstNewline(it.val), p.typeOfPrimitive(it) + return stripFirstNewline(it.val), String{Literal: true, Multiline: true} case itemInteger: return p.valueInteger(it) case itemFloat: @@ -232,15 +257,15 @@ func (p *parser) value(it item, parentIsArray bool) (interface{}, tomlType) { case itemBool: switch it.val { case "true": - return true, p.typeOfPrimitive(it) + return true, Bool{} case "false": - return false, p.typeOfPrimitive(it) + return false, Bool{} default: p.bug("Expected boolean value, but got '%s'.", it.val) } case itemDatetime: return p.valueDatetime(it) - case itemArray: + case itemArrayStart: return p.valueArray(it) case itemInlineTableStart: return p.valueInlineTable(it, parentIsArray) @@ -261,17 +286,30 @@ func (p *parser) valueInteger(it item) (interface{}, tomlType) { num, err := strconv.ParseInt(it.val, 0, 64) if err != nil { // Distinguish integer values. Normally, it'd be a bug if the lexer - // provides an invalid integer, but it's possible that the number is - // out of range of valid values (which the lexer cannot determine). - // So mark the former as a bug but the latter as a legitimate user - // error. + // provides an invalid integer, but it's possible that the number is out + // of range of valid values (which the lexer cannot determine). So mark + // the former as a bug but the latter as a legitimate user error. if e, ok := err.(*strconv.NumError); ok && e.Err == strconv.ErrRange { p.panicItemf(it, "Integer '%s' is out of the range of 64-bit signed integers.", it.val) } else { p.bug("Expected integer value, but got '%s'.", it.val) } } - return num, p.typeOfPrimitive(it) + + v := it.val + if len(v) > 0 && (v[0] == '-' || v[0] == '+') { + v = v[1:] + } + var base uint8 + switch { + case strings.HasPrefix(v, "0b"): + base = 2 + case strings.HasPrefix(v, "0o"): + base = 8 + case strings.HasPrefix(v, "0x"): + base = 16 + } + return num, Int{Base: base} } func (p *parser) valueFloat(it item) (interface{}, tomlType) { @@ -291,10 +329,8 @@ func (p *parser) valueFloat(it item) (interface{}, tomlType) { p.panicItemf(it, "Invalid float %q: cannot have leading zeroes", it.val) } if !numPeriodsOK(it.val) { - // As a special case, numbers like '123.' or '1.e2', - // which are valid as far as Go/strconv are concerned, - // must be rejected because TOML says that a fractional - // part consists of '.' followed by 1+ digits. + // Numbers like '123.' or '1.e2' are valid in Go/strconv, but not valid + // in TOML as a fractional part consists of '.' followed by 1+ digits. p.panicItemf(it, "Invalid float %q: '.' must be followed by one or more digits", it.val) } val := strings.Replace(it.val, "_", "", -1) @@ -309,58 +345,60 @@ func (p *parser) valueFloat(it item) (interface{}, tomlType) { p.panicItemf(it, "Invalid float value: %q", it.val) } } - return num, p.typeOfPrimitive(it) + exp := false + if strings.ContainsAny(val, "eE") { + exp = true + } + return num, Float{Exponent: exp} } var dtTypes = []struct { fmt string zone *time.Location + f DatetimeFormat }{ - {time.RFC3339Nano, time.Local}, - {"2006-01-02T15:04:05.999999999", internal.LocalDatetime}, - {"2006-01-02", internal.LocalDate}, - {"15:04:05.999999999", internal.LocalTime}, + {time.RFC3339Nano, time.Local, DatetimeFormatFull}, + {"2006-01-02T15:04:05.999999999", internal.LocalDatetime, DatetimeFormatLocal}, + {"2006-01-02", internal.LocalDate, DatetimeFormatDate}, + {"15:04:05.999999999", internal.LocalTime, DatetimeFormatTime}, } func (p *parser) valueDatetime(it item) (interface{}, tomlType) { it.val = datetimeRepl.Replace(it.val) - var ( - t time.Time - ok bool - err error - ) for _, dt := range dtTypes { - t, err = time.ParseInLocation(dt.fmt, it.val, dt.zone) + t, err := time.ParseInLocation(dt.fmt, it.val, dt.zone) if err == nil { - ok = true - break + fmt.Printf("Parsed with %s in %s\n %s\n", dt.fmt, dt.zone, t) + return t, Datetime{Format: dt.f} } } - if !ok { - p.panicItemf(it, "Invalid TOML Datetime: %q.", it.val) - } - return t, p.typeOfPrimitive(it) + p.panicItemf(it, "Invalid TOML Datetime: %q.", it.val) + panic("unreachable") } func (p *parser) valueArray(it item) (interface{}, tomlType) { - p.setType(p.currentKey, tomlArray) + p.setType(p.currentKey, Array{}) - // p.setType(p.currentKey, typ) var ( array []interface{} types []tomlType ) for it = p.next(); it.typ != itemArrayEnd; it = p.next() { if it.typ == itemCommentStart { - p.expect(itemText) + p.expect(itemComment) continue } val, typ := p.value(it, true) array = append(array, val) types = append(types, typ) + // XXX: types isn't used here, we need it to record the accurate type + // information. + // + // Not entirely sure how to best store this; could use "key[0]", + // "key[1]" notation, or maybe store it on the Array type? } - return array, tomlArray + return array, Array{} } func (p *parser) valueInlineTable(it item, parentIsArray bool) (interface{}, tomlType) { @@ -380,7 +418,7 @@ func (p *parser) valueInlineTable(it item, parentIsArray bool) (interface{}, tom /// Loop over all table key/value pairs. for it := p.next(); it.typ != itemInlineTableEnd; it = p.next() { if it.typ == itemCommentStart { - p.expect(itemText) + p.expect(itemComment) continue } @@ -413,7 +451,7 @@ func (p *parser) valueInlineTable(it item, parentIsArray bool) (interface{}, tom } p.context = outerContext p.currentKey = outerKey - return hash, tomlHash + return hash, Table{} } // numHasLeadingZero checks if this number has leading zeroes, allowing for '0', @@ -605,7 +643,10 @@ func (p *parser) setType(key string, typ tomlType) { func (p *parser) addImplicit(key Key) { p.implicits[key.String()] = true } func (p *parser) removeImplicit(key Key) { p.implicits[key.String()] = false } func (p *parser) isImplicit(key Key) bool { return p.implicits[key.String()] } -func (p *parser) isArray(key Key) bool { return p.types[key.String()] == tomlArray } +func (p *parser) isArray(key Key) bool { + _, ok := p.types[key.String()].(Array) + return ok +} func (p *parser) addImplicitContext(key Key) { p.addImplicit(key) p.addContext(key, false) diff --git a/toml_test.go b/toml_test.go index 84f6403a..3458b5f3 100644 --- a/toml_test.go +++ b/toml_test.go @@ -182,7 +182,8 @@ func (p parser) Encode(input string) (output string, outputIsError bool, retErr } buf := new(bytes.Buffer) - err = toml.NewEncoder(buf).Encode(rm) + enc := toml.NewEncoder(buf) + err = enc.Encode(rm) if err != nil { return err.Error(), true, retErr } @@ -203,11 +204,12 @@ func (p parser) Decode(input string) (output string, outputIsError bool, retErr }() var d interface{} - if _, err := toml.Decode(input, &d); err != nil { + meta, err := toml.Decode(input, &d) + if err != nil { return err.Error(), true, retErr } - j, err := json.MarshalIndent(tag.Add("", d), "", " ") + j, err := json.MarshalIndent(tag.Add(meta, "", d), "", " ") if err != nil { return "", false, err } diff --git a/type_toml.go b/type_toml.go index 4e90d773..c804f6a4 100644 --- a/type_toml.go +++ b/type_toml.go @@ -1,70 +1,146 @@ package toml -// tomlType represents any Go type that corresponds to a TOML type. -// While the first draft of the TOML spec has a simplistic type system that -// probably doesn't need this level of sophistication, we seem to be militating -// toward adding real composite types. +// tomlType represents a TOML type. type tomlType interface { - typeString() string + tomlType() + String() string } -// typeEqual accepts any two types and returns true if they are equal. -func typeEqual(t1, t2 tomlType) bool { - if t1 == nil || t2 == nil { - return false +type TomlType = tomlType // XXX + +type ( + // Bool represents a TOML boolean. + Bool struct{} + + // String represents a TOML string. + String struct { + Literal bool // As literal string ('..'). + Multiline bool // As multi-line string ("""..""" or '''..'''). } - return t1.typeString() == t2.typeString() -} -func typeIsTable(t tomlType) bool { - return typeEqual(t, tomlHash) || typeEqual(t, tomlArrayHash) -} + // Int represents a TOML integer. + Int struct { + Base uint8 // Base 2, 8, 10, 16, or 0 (same as 10). + Width uint8 // Print leading zeros up to width; ignored for base 10. + } + + // Float represents a TOML float. + Float struct { + Exponent bool // As exponent notation. + } -type tomlBaseType string + // Datetime represents a TOML datetime. + Datetime struct { + Format DatetimeFormat // enum: local, date, time + } -func (btype tomlBaseType) typeString() string { - return string(btype) -} + // DatetimeFormat controls the format to print a datetime. + DatetimeFormat uint8 -func (btype tomlBaseType) String() string { - return btype.typeString() + // Table represents a TOML table. + Table struct { + Inline bool // As inline table. + //Dotted bool + //Merge bool + } + + // Array represents a TOML array. + Array struct { + SingleLine bool // Print on single line. + } + + // ArrayTable represents a TOML array table ([[...]]). + ArrayTable struct { + Inline bool // As inline x = [{..}] rather than [[..]] + } +) + +func (d DatetimeFormat) String() string { + switch d { + default: + return "" + case DatetimeFormatFull: + return "full" + case DatetimeFormatLocal: + return "local" + case DatetimeFormatDate: + return "date" + case DatetimeFormatTime: + return "time" + } } -var ( - tomlInteger tomlBaseType = "Integer" - tomlFloat tomlBaseType = "Float" - tomlDatetime tomlBaseType = "Datetime" - tomlString tomlBaseType = "String" - tomlBool tomlBaseType = "Bool" - tomlArray tomlBaseType = "Array" - tomlHash tomlBaseType = "Hash" - tomlArrayHash tomlBaseType = "ArrayHash" +const ( + _ DatetimeFormat = iota + DatetimeFormatFull // 2021-11-20T15:16:17+01:00 + DatetimeFormatLocal // 2021-11-20T15:16:17 + DatetimeFormatDate // 2021-11-20 + DatetimeFormatTime // 15:16:17 ) -// typeOfPrimitive returns a tomlType of any primitive value in TOML. -// Primitive values are: Integer, Float, Datetime, String and Bool. -// -// Passing a lexer item other than the following will cause a BUG message -// to occur: itemString, itemBool, itemInteger, itemFloat, itemDatetime. -func (p *parser) typeOfPrimitive(lexItem item) tomlType { - switch lexItem.typ { - case itemInteger: - return tomlInteger - case itemFloat: - return tomlFloat - case itemDatetime: - return tomlDatetime - case itemString: - return tomlString - case itemMultilineString: - return tomlString - case itemRawString: - return tomlString - case itemRawMultilineString: - return tomlString - case itemBool: - return tomlBool +func (t Bool) tomlType() {} +func (t String) tomlType() {} +func (t Int) tomlType() {} +func (t Float) tomlType() {} +func (t Datetime) tomlType() {} +func (t Table) tomlType() {} +func (t Array) tomlType() {} +func (t ArrayTable) tomlType() {} +func (t Bool) String() string { return "Bool" } +func (t String) String() string { return "String" } +func (t Int) String() string { return "Integer" } +func (t Float) String() string { return "Float" } +func (t Datetime) String() string { return "Datetime" } +func (t Table) String() string { return "Table" } +func (t Array) String() string { return "Array" } +func (t ArrayTable) String() string { return "ArrayTable" } + +// meta.types may not be defined for a key, so return a zero value. +func asString(t tomlType) String { + if t == nil { + return String{} + } + return t.(String) +} +func asInt(t tomlType) Int { + if t == nil { + return Int{} + } + return t.(Int) +} +func asFloat(t tomlType) Float { + if t == nil { + return Float{} + } + return t.(Float) +} +func asDatetime(t tomlType) Datetime { + if t == nil { + return Datetime{} + } + return t.(Datetime) +} +func asTable(t tomlType) Table { + if t == nil { + return Table{} + } + return t.(Table) +} +func asArray(t tomlType) Array { + if t == nil { + return Array{} } - p.bug("Cannot infer primitive type of lex item '%s'.", lexItem) - panic("unreachable") + return t.(Array) +} + +// typeEqual accepts any two types and returns true if they are equal. +func typeEqual(t1, t2 tomlType) bool { + if t1 == nil || t2 == nil { + return false + } + return t1.String() == t2.String() +} + +func typeIsTable(t tomlType) bool { + return typeEqual(t, Table{}) || typeEqual(t, ArrayTable{}) }