diff --git a/patch.go b/patch.go index 845ac37..9538704 100644 --- a/patch.go +++ b/patch.go @@ -1,10 +1,7 @@ package yamlpatch import ( - "errors" "fmt" - "strconv" - "strings" yaml "gopkg.in/yaml.v2" ) @@ -43,12 +40,14 @@ func (p Patch) Apply(doc []byte) ([]byte, error) { return nil, err } + pathfinder := NewPathFinder(iface) + for _, op := range p { if op.Path.ContainsExtendedSyntax() { var paths []string - paths, err = canonicalPaths(op.Path, iface) - if err != nil { - return nil, err + paths = pathfinder.Find(string(op.Path)) + if paths == nil { + return nil, fmt.Errorf("could not expand pointer: %s", op.Path) } for _, path := range paths { @@ -69,84 +68,3 @@ func (p Patch) Apply(doc []byte) ([]byte, error) { return yaml.Marshal(c) } - -func canonicalPaths(opPath OpPath, obj interface{}) ([]string, error) { - var prefix string - var paths []string - -OUTER_LOOP: - for _, part := range strings.Split(string(opPath), "/") { - if part == "" { - continue - } - - if kv := strings.Split(part, "="); len(kv) == 2 { - _, newPaths := findAllPaths(kv[0], kv[1], obj) - for _, path := range newPaths { - paths = append(paths, fmt.Sprintf("%s/%s", prefix, path)) - } - continue - } - - // this is an optimization to reduce recursive calls to findAllPaths - switch ot := obj.(type) { - case []interface{}: - if idx, err := strconv.Atoi(part); err == nil && idx >= 0 && idx <= len(ot)-1 { - prefix = fmt.Sprintf("%s/%d", prefix, idx) - obj = ot[idx] - continue - } - return nil, fmt.Errorf("invalid index given for array at %s", prefix) - case map[interface{}]interface{}: - for k, v := range ot { - if ks, ok := k.(string); ok && ks == part { - prefix = fmt.Sprintf("%s/%s", prefix, k) - obj = v - continue OUTER_LOOP - } - } - return nil, errors.New("path does not match structure") - } - } - - return paths, nil -} - -func findAllPaths(findKey, findValue string, data interface{}) (interface{}, []string) { - var paths []string - - switch dt := data.(type) { - case map[interface{}]interface{}: - for k, v := range dt { - switch vs := v.(type) { - case string: - if ks, ok := k.(string); ok && ks == findKey && vs == findValue { - return dt, nil - } - default: - _, subPaths := findAllPaths(findKey, findValue, v) - for i := range subPaths { - paths = append(paths, fmt.Sprintf("%s/%s", k, subPaths[i])) - } - } - } - case []interface{}: - for i, v := range dt { - if f, subPaths := findAllPaths(findKey, findValue, v); f != nil { - if len(subPaths) > 0 { - for _, subPath := range subPaths { - paths = append(paths, fmt.Sprintf("%d/%s", i, subPath)) - } - continue - } - paths = append(paths, fmt.Sprintf("%d", i)) - } - } - } - - if paths != nil { - return data, paths - } - - return nil, nil -} diff --git a/patch_test.go b/patch_test.go index 5245563..973cbff 100644 --- a/patch_test.go +++ b/patch_test.go @@ -338,6 +338,30 @@ waldo: - thud: boo - baz: quux corge: grault +`, + ), + Entry("a path that doesn't end with a composite key", + `--- +jobs: +- name: upgrade-opsmgr + serial: true + plan: + - get: pivnet-opsmgr + - put: something-else +`, + `--- +- op: replace + path: /jobs/name=upgrade-opsmgr/plan/1 + value: + get: something-else +`, + `--- +jobs: +- name: upgrade-opsmgr + serial: true + plan: + - get: pivnet-opsmgr + - get: something-else `, ), Entry("removes multiple entries in a single op", diff --git a/pathfinder.go b/pathfinder.go new file mode 100644 index 0000000..604e84b --- /dev/null +++ b/pathfinder.go @@ -0,0 +1,113 @@ +package yamlpatch + +import ( + "fmt" + "strconv" + "strings" +) + +// PathFinder can be used to find RFC6902-standard paths given non-standard +// (key=value) pointer syntax +type PathFinder struct { + root interface{} +} + +// NewPathFinder takes an interface that represents a YAML document and returns +// a new PathFinder +func NewPathFinder(iface interface{}) *PathFinder { + return &PathFinder{ + root: iface, + } +} + +// Find expands the given path into all matching paths, returning the canonical +// versions of those matching paths +func (p *PathFinder) Find(path string) []string { + parts := strings.Split(path, "/") + + if parts[1] == "" { + return []string{"/"} + } + + routes := map[string]interface{}{ + "": p.root, + } + + for _, part := range parts[1:] { + routes = find(part, routes) + } + + var paths []string + for k := range routes { + paths = append(paths, k) + } + + return paths +} + +func find(part string, routes map[string]interface{}) map[string]interface{} { + matches := map[string]interface{}{} + + for prefix, iface := range routes { + if strings.Contains(part, "=") { + kv := strings.Split(part, "=") + if newMatches := findAll(prefix, kv[0], kv[1], iface); len(newMatches) > 0 { + matches = newMatches + } + continue + } + + switch it := iface.(type) { + case map[interface{}]interface{}: + for k, v := range it { + if ks, ok := k.(string); ok && ks == part { + path := fmt.Sprintf("%s/%s", prefix, ks) + matches[path] = v + } + } + case []interface{}: + if idx, err := strconv.Atoi(part); err == nil && idx >= 0 && idx <= len(it)-1 { + path := fmt.Sprintf("%s/%d", prefix, idx) + matches[path] = it[idx] + } + default: + panic(fmt.Sprintf("don't know how to handle %T: %s", iface, iface)) + } + } + + return matches +} + +func findAll(prefix, findKey, findValue string, iface interface{}) map[string]interface{} { + matches := map[string]interface{}{} + + switch it := iface.(type) { + case map[interface{}]interface{}: + for k, v := range it { + if ks, ok := k.(string); ok { + switch vs := v.(type) { + case string: + if ks == findKey && vs == findValue { + return map[string]interface{}{ + prefix: it, + } + } + default: + for route, match := range findAll(fmt.Sprintf("%s/%s", prefix, ks), findKey, findValue, v) { + matches[route] = match + } + } + } + } + case []interface{}: + for i, v := range it { + for route, match := range findAll(fmt.Sprintf("%s/%d", prefix, i), findKey, findValue, v) { + matches[route] = match + } + } + default: + panic(fmt.Sprintf("don't know how to handle %T: %s", iface, iface)) + } + + return matches +} diff --git a/pathfinder_test.go b/pathfinder_test.go new file mode 100644 index 0000000..3e59925 --- /dev/null +++ b/pathfinder_test.go @@ -0,0 +1,69 @@ +package yamlpatch_test + +import ( + yamlpatch "github.com/krishicks/yaml-patch" + yaml "gopkg.in/yaml.v2" + + . "github.com/onsi/ginkgo" + . "github.com/onsi/ginkgo/extensions/table" + . "github.com/onsi/gomega" +) + +var _ = Describe("Pathfinder", func() { + var pathfinder *yamlpatch.PathFinder + + BeforeEach(func() { + var iface interface{} + + bs := []byte(` +jobs: +- name: job1 + plan: + - get: A + args: + - arg: arg1 + - arg: arg2 + - get: B + +- name: job2 + plan: + - aggregate: + - get: C + - get: A +`) + + err := yaml.Unmarshal(bs, &iface) + Expect(err).NotTo(HaveOccurred()) + pathfinder = yamlpatch.NewPathFinder(iface) + }) + + Describe("Find", func() { + DescribeTable( + "should", + func(path string, expected []string) { + actual := pathfinder.Find(path) + Expect(actual).To(HaveLen(len(expected))) + for _, el := range expected { + Expect(actual).To(ContainElement(el)) + } + }, + Entry("return a route for the root object", "/", []string{"/"}), + Entry("return a route for an object under the root", "/jobs", []string{"/jobs"}), + Entry("return a route for an element within an object under the root", "/jobs/0", []string{"/jobs/0"}), + Entry("return a route for an object within an element within an object under the root", "/jobs/0/plan", []string{"/jobs/0/plan"}), + Entry("return a route for an object within an element within an object under the root", "/jobs/0/plan/1", []string{"/jobs/0/plan/1"}), + Entry("return routes for multiple matches", "/jobs/get=A", []string{"/jobs/0/plan/0", "/jobs/1/plan/0/aggregate/1"}), + Entry("return a route for a single submatch with help", "/jobs/get=A/args/arg=arg2", []string{"/jobs/0/plan/0/args/1"}), + Entry("return a route for a single submatch with no help", "/jobs/get=A/arg=arg2", []string{"/jobs/0/plan/0/args/1"}), + ) + DescribeTable( + "should not", + func(path string) { + Expect(pathfinder.Find(path)).To(BeNil()) + }, + Entry("return any routes when given a bad index", "/jobs/2"), + Entry("return any routes when given a bad index", "/jobs/-1"), + Entry("return any routes when given a bad pointer", "/plan"), + ) + }) +})