From 82718c9031b56344b0c230bc5ef060a9f20e40b6 Mon Sep 17 00:00:00 2001 From: Matt Joiner Date: Wed, 26 Jun 2024 23:30:08 +1000 Subject: [PATCH] bencode: Support unmarshalling into maps with non-string key types Fixes #952 --- bencode/decode.go | 47 +++++++++++++++++++++++-------- bencode/decode_test.go | 2 +- tests/issue-952/issue-952_test.go | 38 +++++++++++++++++++++++++ 3 files changed, 74 insertions(+), 13 deletions(-) create mode 100644 tests/issue-952/issue-952_test.go diff --git a/bencode/decode.go b/bencode/decode.go index 3839b849c2..ce1f8b465f 100644 --- a/bencode/decode.go +++ b/bencode/decode.go @@ -281,7 +281,7 @@ type dictField struct { } // Returns specifics for parsing a dict field value. -func getDictField(dict reflect.Type, key string) (_ dictField, err error) { +func getDictField(dict reflect.Type, key reflect.Value) (_ dictField, err error) { // get valuev as a map value or as a struct field switch k := dict.Kind(); k { case reflect.Map: @@ -293,13 +293,18 @@ func getDictField(dict reflect.Type, key string) (_ dictField, err error) { mapValue.Set(reflect.MakeMap(dict)) } // Assigns the value into the map. - // log.Printf("map type: %v", mapValue.Type()) - mapValue.SetMapIndex(reflect.ValueOf(key).Convert(dict.Key()), value) + mapValue.SetMapIndex(key, value) } }, }, nil case reflect.Struct: - return getStructFieldForKey(dict, key), nil + if key.Kind() != reflect.String { + // This doesn't make sense for structs. They have to use strings. If they didn't they + // should at least have things that convert to strings trivially and somehow much the + // bencode tag. + panic(key) + } + return getStructFieldForKey(dict, key.String()), nil // if sf.r.PkgPath != "" { // panic(&UnmarshalFieldError{ // Key: key, @@ -382,11 +387,29 @@ func getStructFieldForKey(struct_ reflect.Type, key string) (f dictField) { return } +var structKeyType = reflect.TypeFor[string]() + +func keyType(v reflect.Value) reflect.Type { + switch v.Kind() { + case reflect.Map: + return v.Type().Key() + case reflect.Struct: + return structKeyType + default: + return nil + } +} + func (d *Decoder) parseDict(v reflect.Value) error { - // At this point 'd' byte was consumed, now read key/value pairs + // At this point 'd' byte was consumed, now read key/value pairs. + + // The key type does not need to be a string for maps. + keyType := keyType(v) + if keyType == nil { + return fmt.Errorf("cannot parse dicts into %v", v.Type()) + } for { - var keyStr string - keyValue := reflect.ValueOf(&keyStr).Elem() + keyValue := reflect.New(keyType).Elem() ok, err := d.parseValue(keyValue) if err != nil { return fmt.Errorf("error parsing dict key: %w", err) @@ -395,7 +418,7 @@ func (d *Decoder) parseDict(v reflect.Value) error { return nil } - df, err := getDictField(v.Type(), keyStr) + df, err := getDictField(v.Type(), keyValue) if err != nil { return fmt.Errorf("parsing bencode dict into %v: %w", v.Type(), err) } @@ -406,10 +429,10 @@ func (d *Decoder) parseDict(v reflect.Value) error { var if_ interface{} if_, ok = d.parseValueInterface() if if_ == nil { - return fmt.Errorf("error parsing value for key %q", keyStr) + return fmt.Errorf("error parsing value for key %q", keyValue) } if !ok { - return fmt.Errorf("missing value for key %q", keyStr) + return fmt.Errorf("missing value for key %q", keyValue) } continue } @@ -419,11 +442,11 @@ func (d *Decoder) parseDict(v reflect.Value) error { if err != nil { var target *UnmarshalTypeError if !(errors.As(err, &target) && df.Tags.IgnoreUnmarshalTypeError()) { - return fmt.Errorf("parsing value for key %q: %w", keyStr, err) + return fmt.Errorf("parsing value for key %q: %w", keyValue, err) } } if !ok { - return fmt.Errorf("missing value for key %q", keyStr) + return fmt.Errorf("missing value for key %q", keyValue) } df.Get(v)(setValue) } diff --git a/bencode/decode_test.go b/bencode/decode_test.go index 4d05d2b332..a350beea31 100644 --- a/bencode/decode_test.go +++ b/bencode/decode_test.go @@ -184,7 +184,7 @@ func TestDecodeDictIntoUnsupported(t *testing.T) { c := qt.New(t) err := Unmarshal([]byte("d1:a1:be"), &i) t.Log(err) - c.Check(err, qt.Not(qt.IsNil)) + c.Check(err, qt.IsNotNil) } func TestUnmarshalDictKeyNotString(t *testing.T) { diff --git a/tests/issue-952/issue-952_test.go b/tests/issue-952/issue-952_test.go new file mode 100644 index 0000000000..8cf154119c --- /dev/null +++ b/tests/issue-952/issue-952_test.go @@ -0,0 +1,38 @@ +package issue_952 + +import ( + "github.com/anacrolix/torrent/bencode" + "github.com/anacrolix/torrent/metainfo" + "github.com/anacrolix/torrent/types/infohash" + qt "github.com/frankban/quicktest" + "testing" +) + +type scrapeResponse struct { + Files map[metainfo.Hash]scrapeResponseFile `bencode:"files"` +} + +type scrapeResponseFile struct { + Complete int `bencode:"complete"` + Downloaded int `bencode:"downloaded"` + Incomplete int `bencode:"incomplete"` +} + +// This tests unmarshalling to a map with a non-string dict key. +func TestUnmarshalStringToByteArray(t *testing.T) { + var s scrapeResponse + const hashStr = "\x05a~F\xfd{c\xd1`\xb8\xd9\x89\xceM\xb9t\x1d\\\x0b\xded" + err := bencode.Unmarshal([]byte("d5:filesd20:\x05a~F\xfd{c\xd1`\xb8\xd9\x89\xceM\xb9t\x1d\\\x0b\xded9:completedi1e10:downloadedi1eeee"), &s) + c := qt.New(t) + c.Assert(err, qt.IsNil) + c.Check(s.Files, qt.HasLen, 1) + file, ok := s.Files[(infohash.T)([]byte(hashStr))] + c.Assert(ok, qt.IsTrue) + c.Check(file, qt.Equals, scrapeResponseFile{ + // Note that complete is misspelled in the example. I don't know why. + Complete: 0, + Downloaded: 1, + Incomplete: 0, + }) + +}