diff --git a/assets/pango/example/main.go b/assets/pango/example/main.go index 500c6a54..bf623028 100644 --- a/assets/pango/example/main.go +++ b/assets/pango/example/main.go @@ -9,6 +9,7 @@ import ( "github.com/PaloAltoNetworks/pango" "github.com/PaloAltoNetworks/pango/device/services/dns" "github.com/PaloAltoNetworks/pango/device/services/ntp" + "github.com/PaloAltoNetworks/pango/movement" "github.com/PaloAltoNetworks/pango/network/interface/ethernet" "github.com/PaloAltoNetworks/pango/network/interface/loopback" "github.com/PaloAltoNetworks/pango/network/profiles/interface_management" @@ -23,7 +24,6 @@ import ( "github.com/PaloAltoNetworks/pango/panorama/template" "github.com/PaloAltoNetworks/pango/panorama/template_stack" "github.com/PaloAltoNetworks/pango/policies/rules/security" - "github.com/PaloAltoNetworks/pango/rule" "github.com/PaloAltoNetworks/pango/util" ) @@ -773,22 +773,11 @@ func checkSecurityPolicyRulesMove(c *pango.Client, ctx context.Context) { log.Printf("Security policy rule '%s:%s' with description '%s' created", *securityPolicyRuleItemReply.Uuid, securityPolicyRuleItemReply.Name, *securityPolicyRuleItemReply.Description) } - rulePositionBefore7 := rule.Position{ - First: nil, - Last: nil, - SomewhereBefore: nil, - DirectlyBefore: util.String("codegen_rule7"), - SomewhereAfter: nil, - DirectlyAfter: nil, - } - rulePositionBottom := rule.Position{ - First: nil, - Last: util.Bool(true), - SomewhereBefore: nil, - DirectlyBefore: nil, - SomewhereAfter: nil, - DirectlyAfter: nil, + positionBefore7 := movement.PositionBefore{ + Directly: true, + Pivot: "codegen_rule7", } + positionLast := movement.PositionLast{} var securityPolicyRulesEntriesToMove []*security.Entry securityPolicyRulesEntriesToMove = append(securityPolicyRulesEntriesToMove, securityPolicyRulesEntries[3]) @@ -797,7 +786,7 @@ func checkSecurityPolicyRulesMove(c *pango.Client, ctx context.Context) { for _, securityPolicyRuleItemToMove := range securityPolicyRulesEntriesToMove { log.Printf("Security policy rule '%s' is going to be moved", securityPolicyRuleItemToMove.Name) } - err := securityPolicyRuleApi.MoveGroup(ctx, *securityPolicyRuleLocation, rulePositionBefore7, securityPolicyRulesEntriesToMove) + err := securityPolicyRuleApi.MoveGroup(ctx, *securityPolicyRuleLocation, positionBefore7, securityPolicyRulesEntriesToMove) if err != nil { log.Printf("Failed to move security policy rules %v: %s", securityPolicyRulesEntriesToMove, err) return @@ -807,7 +796,7 @@ func checkSecurityPolicyRulesMove(c *pango.Client, ctx context.Context) { for _, securityPolicyRuleItemToMove := range securityPolicyRulesEntriesToMove { log.Printf("Security policy rule '%s' is going to be moved", securityPolicyRuleItemToMove.Name) } - err = securityPolicyRuleApi.MoveGroup(ctx, *securityPolicyRuleLocation, rulePositionBottom, securityPolicyRulesEntriesToMove) + err = securityPolicyRuleApi.MoveGroup(ctx, *securityPolicyRuleLocation, positionLast, securityPolicyRulesEntriesToMove) if err != nil { log.Printf("Failed to move security policy rules %v: %s", securityPolicyRulesEntriesToMove, err) return diff --git a/assets/pango/movement/movement.go b/assets/pango/movement/movement.go new file mode 100644 index 00000000..c097f43b --- /dev/null +++ b/assets/pango/movement/movement.go @@ -0,0 +1,579 @@ +package movement + +import ( + "errors" + "log/slog" + "slices" +) + +var _ = slog.LevelDebug + +type ActionWhereType string + +const ( + ActionWhereFirst ActionWhereType = "top" + ActionWhereLast ActionWhereType = "bottom" + ActionWhereBefore ActionWhereType = "before" + ActionWhereAfter ActionWhereType = "after" +) + +type Movable interface { + EntryName() string +} + +type MoveAction struct { + Movable Movable + Where ActionWhereType + Destination Movable +} + +type Position interface { + Move(entries []Movable, existing []Movable) ([]MoveAction, error) + GetExpected(entries []Movable, existing []Movable) ([]Movable, error) + IsDirectly() bool + Where() ActionWhereType + PivotEntryName() string +} + +type PositionFirst struct{} + +func (o PositionFirst) IsDirectly() bool { + return false +} + +func (o PositionFirst) Where() ActionWhereType { + return ActionWhereFirst +} + +func (o PositionFirst) PivotEntryName() string { + return "" +} + +type PositionLast struct{} + +func (o PositionLast) IsDirectly() bool { + return false +} + +func (o PositionLast) Where() ActionWhereType { + return ActionWhereLast +} + +func (o PositionLast) PivotEntryName() string { + return "" +} + +type PositionBefore struct { + Directly bool + Pivot string +} + +func (o PositionBefore) IsDirectly() bool { + return o.Directly +} + +func (o PositionBefore) Where() ActionWhereType { + return ActionWhereBefore +} + +func (o PositionBefore) PivotEntryName() string { + return o.Pivot +} + +type PositionAfter struct { + Directly bool + Pivot string +} + +func (o PositionAfter) IsDirectly() bool { + return o.Directly +} + +func (o PositionAfter) Where() ActionWhereType { + return ActionWhereAfter +} + +func (o PositionAfter) PivotEntryName() string { + return o.Pivot +} + +type entryWithIdx[E Movable] struct { + Entry E + Idx int +} + +func entriesByName[E Movable](entries []E) map[string]entryWithIdx[E] { + entriesIdxMap := make(map[string]entryWithIdx[E], len(entries)) + for idx, elt := range entries { + entriesIdxMap[elt.EntryName()] = entryWithIdx[E]{ + Entry: elt, + Idx: idx, + } + } + return entriesIdxMap +} + +func removeEntriesFromExisting(entries []Movable, filterFn func(entry Movable) bool) []Movable { + entryNames := make(map[string]bool, len(entries)) + for _, elt := range entries { + entryNames[elt.EntryName()] = true + } + + filtered := make([]Movable, len(entries)) + copy(filtered, entries) + + filtered = slices.DeleteFunc(filtered, filterFn) + + return filtered +} + +func findPivotIdx(entries []Movable, pivot string) (int, Movable) { + var pivotEntry Movable + pivotIdx := slices.IndexFunc(entries, func(entry Movable) bool { + if entry.EntryName() == pivot { + pivotEntry = entry + return true + } + + return false + }) + + return pivotIdx, pivotEntry +} + +var ( + errNoMovements = errors.New("no movements needed") + ErrSlicesNotEqualLength = errors.New("existing and expected slices length mismatch") + ErrPivotInEntries = errors.New("pivot element found in the entries slice") + ErrPivotNotInExisting = errors.New("pivot element not foudn in the existing slice") + ErrInvalidMovementPlan = errors.New("created movement plan is invalid") +) + +// PositionBefore and PositionAfter are similar enough that we can generate expected sequences +// for both using the same code and some conditionals based on the given movement. +func getPivotMovement(entries []Movable, existing []Movable, pivot string, direct bool, movement ActionWhereType) ([]Movable, error) { + existingIdxMap := entriesByName(existing) + + entriesPivotIdx, _ := findPivotIdx(entries, pivot) + if entriesPivotIdx != -1 { + return nil, ErrPivotInEntries + } + + existingPivotIdx, _ := findPivotIdx(existing, pivot) + if existingPivotIdx == -1 { + return nil, ErrPivotNotInExisting + } + + if !direct { + movementRequired := false + entriesLen := len(entries) + loop: + for i := 0; i < entriesLen; i++ { + existingEntryIdx := existingIdxMap[entries[i].EntryName()].Idx + // For any given entry in the list of entries to move check if the entry + // index is at or after pivot point index, which will require movement + // set to be generated. + + // Then check if the entries to be moved have the same order in the existing + // slice, and if not require a movement set to be generated. + switch movement { + case ActionWhereBefore: + if existingEntryIdx >= existingPivotIdx { + movementRequired = true + break + } + + if i == 0 { + continue + } + + if existingIdxMap[entries[i-1].EntryName()].Idx >= existingEntryIdx { + movementRequired = true + break loop + + } + case ActionWhereAfter: + if existingEntryIdx <= existingPivotIdx { + movementRequired = true + break + } + + if i == len(entries)-1 { + continue + } + + if existingIdxMap[entries[i+1].EntryName()].Idx < existingEntryIdx { + movementRequired = true + break loop + + } + + } + } + + if !movementRequired { + return nil, errNoMovements + } + } + + expected := make([]Movable, len(existing)) + + entriesIdxMap := entriesByName(entries) + + filtered := removeEntriesFromExisting(existing, func(entry Movable) bool { + _, ok := entriesIdxMap[entry.EntryName()] + return ok + }) + + filteredPivotIdx, pivotEntry := findPivotIdx(filtered, pivot) + + switch movement { + case ActionWhereBefore: + expectedIdx := 0 + for ; expectedIdx < filteredPivotIdx; expectedIdx++ { + expected[expectedIdx] = filtered[expectedIdx] + } + + for _, elt := range entries { + expected[expectedIdx] = elt + expectedIdx++ + } + + expected[expectedIdx] = pivotEntry + expectedIdx++ + + filteredLen := len(filtered) + for i := filteredPivotIdx + 1; i < filteredLen; i++ { + expected[expectedIdx] = filtered[i] + expectedIdx++ + } + + case ActionWhereAfter: + expectedIdx := 0 + for ; expectedIdx < filteredPivotIdx+1; expectedIdx++ { + expected[expectedIdx] = filtered[expectedIdx] + } + + if direct { + for _, elt := range entries { + expected[expectedIdx] = elt + expectedIdx++ + } + + filteredLen := len(filtered) + for i := filteredPivotIdx + 1; i < filteredLen; i++ { + expected[expectedIdx] = filtered[i] + } + } else { + filteredLen := len(filtered) + for i := filteredPivotIdx + 1; i < filteredLen; i++ { + expected[expectedIdx] = filtered[i] + expectedIdx++ + } + + for _, elt := range entries { + expected[expectedIdx] = elt + expectedIdx++ + } + + } + } + + return expected, nil +} + +func (o PositionAfter) GetExpected(entries []Movable, existing []Movable) ([]Movable, error) { + return getPivotMovement(entries, existing, o.Pivot, o.Directly, ActionWhereAfter) +} + +func (o PositionAfter) Move(entries []Movable, existing []Movable) ([]MoveAction, error) { + expected, err := o.GetExpected(entries, existing) + if err != nil { + if errors.Is(err, errNoMovements) { + return nil, nil + } + return nil, err + } + + actions, err := GenerateMovements(existing, expected, entries, ActionWhereAfter, o.Pivot, o.Directly) + if err != nil { + return nil, err + } + + return OptimizeMovements(existing, expected, entries, actions, o), nil +} + +func (o PositionBefore) GetExpected(entries []Movable, existing []Movable) ([]Movable, error) { + return getPivotMovement(entries, existing, o.Pivot, o.Directly, ActionWhereBefore) +} + +func (o PositionBefore) Move(entries []Movable, existing []Movable) ([]MoveAction, error) { + expected, err := o.GetExpected(entries, existing) + if err != nil { + if errors.Is(err, errNoMovements) { + return nil, nil + } + return nil, err + } + + actions, err := GenerateMovements(existing, expected, entries, ActionWhereBefore, o.Pivot, o.Directly) + if err != nil { + return nil, err + } + + return OptimizeMovements(existing, expected, entries, actions, o), nil +} + +type Entry struct { + Element Movable + Expected int + Existing int +} + +type sequencePosition struct { + Start int + End int +} + +func updateSimulatedIdxMap[E Movable](idxMap *map[string]entryWithIdx[E], moved Movable, startingIdx int, targetIdx int) { + for name, entry := range *idxMap { + if name == moved.EntryName() { + continue + } + + idx := entry.Idx + + if startingIdx > targetIdx && idx >= targetIdx { + entry.Idx = idx + 1 + (*idxMap)[name] = entry + } else if startingIdx < targetIdx && idx >= startingIdx && idx <= targetIdx { + entry.Idx = idx - 1 + (*idxMap)[name] = entry + } + } +} + +func OptimizeMovements(existing []Movable, expected []Movable, entries []Movable, actions []MoveAction, position Position) []MoveAction { + simulated := make([]Movable, len(existing)) + copy(simulated, existing) + simulatedIdxMap := entriesByName(simulated) + + var optimized []MoveAction + for _, action := range actions { + currentIdx := simulatedIdxMap[action.Movable.EntryName()].Idx + + var targetIdx int + switch action.Where { + case ActionWhereFirst: + targetIdx = 0 + case ActionWhereLast: + targetIdx = len(simulated) - 1 + case ActionWhereBefore: + targetIdx = simulatedIdxMap[action.Destination.EntryName()].Idx + case ActionWhereAfter: + targetIdx = simulatedIdxMap[action.Destination.EntryName()].Idx + 1 + } + + slog.Debug("OptimizeMovements()", "action", action, "currentIdx", currentIdx, "targetIdx", targetIdx) + + if targetIdx != currentIdx { + optimized = append(optimized, action) + entry := simulatedIdxMap[action.Movable.EntryName()] + entry.Idx = targetIdx + simulatedIdxMap[action.Movable.EntryName()] = entry + updateSimulatedIdxMap(&simulatedIdxMap, action.Movable, currentIdx, targetIdx) + } + } + + slog.Debug("OptimizeMovements()", "optimized", optimized) + + return optimized +} + +func GenerateMovements(existing []Movable, expected []Movable, entries []Movable, movement ActionWhereType, pivot string, directly bool) ([]MoveAction, error) { + if len(existing) != len(expected) { + slog.Error("GenerateMovements()", "len(existing)", len(existing), "len(expected)", len(expected)) + return nil, ErrSlicesNotEqualLength + } + + entriesIdxMap := entriesByName(entries) + existingIdxMap := entriesByName(existing) + expectedIdxMap := entriesByName(expected) + + var movements []MoveAction + var previous Movable + for _, elt := range entries { + slog.Debug("GeneraveMovements()", "elt", elt, "existing", existingIdxMap[elt.EntryName()], "expected", expectedIdxMap[elt.EntryName()]) + + if previous != nil { + movements = append(movements, MoveAction{ + Movable: elt, + Destination: previous, + Where: ActionWhereAfter, + }) + previous = elt + continue + } + if expectedIdxMap[elt.EntryName()].Idx == 0 { + movements = append(movements, MoveAction{ + Movable: elt, + Destination: nil, + Where: ActionWhereFirst, + }) + previous = elt + } else if expectedIdxMap[elt.EntryName()].Idx == len(expectedIdxMap)-1 { + movements = append(movements, MoveAction{ + Movable: elt, + Destination: nil, + Where: ActionWhereLast, + }) + previous = elt + } else { + var where ActionWhereType + + var pivot Movable + switch movement { + case ActionWhereLast: + where = ActionWhereLast + case ActionWhereAfter: + pivot = expected[expectedIdxMap[elt.EntryName()].Idx-1] + where = ActionWhereAfter + case ActionWhereFirst: + pivot = existing[0] + where = ActionWhereBefore + case ActionWhereBefore: + eltExpectedIdx := expectedIdxMap[elt.EntryName()].Idx + pivot = expected[eltExpectedIdx+1] + where = ActionWhereBefore + // When entries are to be put directly before the pivot point, if previous was nil (we + // are processing the first element in entries set) and selected pivot is part of the + // entries set, we need to find the actual pivot, i.e. element of the expected list + // that directly follows all elements from the entries set. + if _, ok := entriesIdxMap[pivot.EntryName()]; ok && directly { + // The actual pivot for the move is the element that follows all elements + // from the existing set. + pivotIdx := eltExpectedIdx + len(entries) + if pivotIdx >= len(expected) { + // This should never happen as by definition there is at least + // element (pivot point) at the end of the expected slice. + return nil, ErrInvalidMovementPlan + } + pivot = expected[pivotIdx] + + } + } + + movements = append(movements, MoveAction{ + Movable: elt, + Destination: pivot, + Where: where, + }) + previous = elt + } + + } + + slog.Debug("GenerateMovements()", "movements", movements) + + return movements, nil +} + +func (o PositionFirst) GetExpected(entries []Movable, existing []Movable) ([]Movable, error) { + entriesIdxMap := entriesByName(entries) + + filtered := removeEntriesFromExisting(existing, func(entry Movable) bool { + _, ok := entriesIdxMap[entry.EntryName()] + return ok + }) + + expected := append(entries, filtered...) + + return expected, nil +} + +func (o PositionFirst) Move(entries []Movable, existing []Movable) ([]MoveAction, error) { + expected, err := o.GetExpected(entries, existing) + if err != nil { + return nil, err + } + + slog.Error("PositionFirst.Move()", "len(expected)", len(expected), "len(existing)", len(existing)) + + actions, err := GenerateMovements(existing, expected, entries, ActionWhereFirst, "", false) + if err != nil { + return nil, err + } + + return OptimizeMovements(existing, expected, entries, actions, o), nil +} + +func (o PositionLast) GetExpected(entries []Movable, existing []Movable) ([]Movable, error) { + entriesIdxMap := entriesByName(entries) + + filtered := removeEntriesFromExisting(existing, func(entry Movable) bool { + _, ok := entriesIdxMap[entry.EntryName()] + return ok + }) + + expected := append(filtered, entries...) + + return expected, nil +} + +func (o PositionLast) Move(entries []Movable, existing []Movable) ([]MoveAction, error) { + expected, err := o.GetExpected(entries, existing) + if err != nil { + return nil, err + } + + actions, err := GenerateMovements(existing, expected, entries, ActionWhereLast, "", false) + if err != nil { + slog.Debug("PositionLast()", "err", err) + return nil, err + } + return OptimizeMovements(existing, expected, entries, actions, o), nil +} + +type Movement struct { + Entries []Movable + Position Position +} + +func MoveGroups[E Movable](existing []Movable, movements []Movement) ([]MoveAction, error) { + expected := existing + for idx := range len(movements) - 1 { + position := movements[idx].Position + entries := movements[idx].Entries + result, err := position.GetExpected(entries, expected) + if err != nil { + if !errors.Is(err, errNoMovements) { + return nil, err + } + continue + } + expected = result + } + + entries := movements[len(movements)-1].Entries + position := movements[len(movements)-1].Position + return position.Move(entries, expected) +} + +func MoveGroup[E Movable](position Position, entries []E, existing []E) ([]MoveAction, error) { + var movableEntries []Movable + for _, elt := range entries { + slog.Warn("MoveGroup", "entry.EntryName()", elt.EntryName()) + movableEntries = append(movableEntries, elt) + } + var movableExisting []Movable + for _, elt := range existing { + slog.Warn("MoveGroup", "existing.EntryName()", elt.EntryName()) + movableExisting = append(movableExisting, elt) + } + return position.Move(movableEntries, movableExisting) +} + +type Move struct { + Position Position + Existing []Movable +} diff --git a/assets/pango/movement/movement_suite_test.go b/assets/pango/movement/movement_suite_test.go new file mode 100644 index 00000000..b750b000 --- /dev/null +++ b/assets/pango/movement/movement_suite_test.go @@ -0,0 +1,18 @@ +package movement_test + +import ( + "log/slog" + "testing" + + . "github.com/onsi/ginkgo/v2" + . "github.com/onsi/gomega" +) + +func TestMovement(t *testing.T) { + handler := slog.NewTextHandler(GinkgoWriter, &slog.HandlerOptions{ + Level: slog.LevelDebug, + }) + slog.SetDefault(slog.New(handler)) + RegisterFailHandler(Fail) + RunSpecs(t, "Movement Suite") +} diff --git a/assets/pango/movement/movement_test.go b/assets/pango/movement/movement_test.go new file mode 100644 index 00000000..9e5e9026 --- /dev/null +++ b/assets/pango/movement/movement_test.go @@ -0,0 +1,411 @@ +package movement_test + +import ( + "fmt" + + . "github.com/onsi/ginkgo/v2" + . "github.com/onsi/gomega" + + "movements/movement" +) + +var _ = fmt.Printf + +type Mock struct { + Name string +} + +func (o Mock) EntryName() string { + return o.Name +} + +func asMovable(mocks []string) []movement.Movable { + var movables []movement.Movable + + for _, elt := range mocks { + movables = append(movables, Mock{elt}) + } + + return movables +} + +var _ = Describe("MoveGroup()", func() { + Context("With PositionFirst used as position", func() { + Context("when existing positions matches expected", func() { + It("should generate no movements", func() { + // '(A B C) -> '(A B C) + expected := asMovable([]string{"A", "B", "C"}) + moves, err := movement.MoveGroup(movement.PositionFirst{}, expected, expected) + Expect(err).ToNot(HaveOccurred()) + Expect(moves).To(HaveLen(0)) + }) + }) + Context("when it has to move two elements", func() { + It("should generate three move actions", func() { + // '(D E A B C) -> '(A B C D E) + entries := asMovable([]string{"A", "B", "C"}) + existing := asMovable([]string{"D", "E", "A", "B", "C"}) + + moves, err := movement.MoveGroup(movement.PositionFirst{}, entries, existing) + Expect(err).ToNot(HaveOccurred()) + Expect(moves).To(HaveLen(3)) + + Expect(moves[0].Movable.EntryName()).To(Equal("A")) + Expect(moves[0].Where).To(Equal(movement.ActionWhereFirst)) + Expect(moves[0].Destination).To(BeNil()) + + Expect(moves[1].Movable.EntryName()).To(Equal("B")) + Expect(moves[1].Where).To(Equal(movement.ActionWhereAfter)) + Expect(moves[1].Destination.EntryName()).To(Equal("A")) + + Expect(moves[2].Movable.EntryName()).To(Equal("C")) + Expect(moves[2].Where).To(Equal(movement.ActionWhereAfter)) + Expect(moves[2].Destination.EntryName()).To(Equal("B")) + }) + }) + Context("when expected order is reversed", func() { + It("should generate required move actions to converge lists", func() { + // '(A B C D E) -> '(E D C B A) + entries := asMovable([]string{"E", "D", "C", "B", "A"}) + existing := asMovable([]string{"A", "B", "C", "D", "E"}) + moves, err := movement.MoveGroup(movement.PositionFirst{}, entries, existing) + Expect(err).ToNot(HaveOccurred()) + + // '((E 'top nil)(B 'after E)(C 'after B)(D 'after C)) + // 'A element stays in place + Expect(moves).To(HaveLen(4)) + }) + }) + }) + Context("With PositionLast used as position", func() { + Context("with non-consecutive entries", func() { + It("should generate two move actions", func() { + // '(A E B C D) -> '(A B D E C) + entries := asMovable([]string{"E", "C"}) + existing := asMovable([]string{"A", "E", "B", "C", "D"}) + + moves, err := movement.MoveGroup(movement.PositionLast{}, entries, existing) + Expect(err).ToNot(HaveOccurred()) + Expect(moves).To(HaveLen(2)) + + Expect(moves[0].Movable.EntryName()).To(Equal("E")) + Expect(moves[0].Where).To(Equal(movement.ActionWhereLast)) + Expect(moves[0].Destination).To(BeNil()) + + Expect(moves[1].Movable.EntryName()).To(Equal("C")) + Expect(moves[1].Where).To(Equal(movement.ActionWhereAfter)) + Expect(moves[1].Destination.EntryName()).To(Equal("E")) + }) + }) + }) + Context("With PositionLast used as position", func() { + Context("when it needs to move one element", func() { + It("should generate a single move action", func() { + // '(A E B C D) -> '(A B C D E) + entries := asMovable([]string{"E"}) + existing := asMovable([]string{"A", "E", "B", "C", "D"}) + + moves, err := movement.MoveGroup(movement.PositionLast{}, entries, existing) + Expect(err).ToNot(HaveOccurred()) + Expect(moves).To(HaveLen(1)) + + Expect(moves[0].Movable.EntryName()).To(Equal("E")) + Expect(moves[0].Where).To(Equal(movement.ActionWhereLast)) + Expect(moves[0].Destination).To(BeNil()) + }) + }) + }) + + Context("With PositionAfter used as position", func() { + existing := asMovable([]string{"A", "B", "C", "D", "E"}) + Context("when direct position relative to the pivot is not required", func() { + It("should not generate any move actions", func() { + // '(A B C D E) -> '(A B C D E) + entries := asMovable([]string{"D", "E"}) + moves, err := movement.MoveGroup( + movement.PositionAfter{Directly: false, Pivot: "B"}, + entries, existing, + ) + + Expect(err).ToNot(HaveOccurred()) + Expect(moves).To(HaveLen(0)) + }) + Context("and moved entries are out of order", func() { + It("should generate a single command to move B before D", func() { + // '(A B C D E) -> '(A B C E D) + entries := asMovable([]string{"E", "D"}) + moves, err := movement.MoveGroup( + movement.PositionAfter{Directly: false, Pivot: "B"}, + entries, existing, + ) + + Expect(err).ToNot(HaveOccurred()) + Expect(moves).To(HaveLen(1)) + + Expect(moves[0].Movable.EntryName()).To(Equal("E")) + Expect(moves[0].Where).To(Equal(movement.ActionWhereAfter)) + Expect(moves[0].Destination.EntryName()).To(Equal("C")) + }) + }) + }) + Context("when direct position relative to the pivot is required", func() { + It("should generate required move actions", func() { + // '(A B C D E) -> '(C D A B E) + entries := asMovable([]string{"A", "B"}) + moves, err := movement.MoveGroup( + movement.PositionAfter{Directly: true, Pivot: "D"}, + entries, existing, + ) + + Expect(err).ToNot(HaveOccurred()) + Expect(moves).To(HaveLen(2)) + + Expect(moves[0].Movable.EntryName()).To(Equal("A")) + Expect(moves[0].Where).To(Equal(movement.ActionWhereAfter)) + Expect(moves[0].Destination.EntryName()).To(Equal("D")) + + Expect(moves[1].Movable.EntryName()).To(Equal("B")) + Expect(moves[1].Where).To(Equal(movement.ActionWhereAfter)) + Expect(moves[1].Destination.EntryName()).To(Equal("A")) + }) + }) + Context("when direct position relative to the pivot is required", func() { + It("should generate required move actions", func() { + // '(A B C D E) -> '(C D B A E) + entries := asMovable([]string{"B", "A"}) + moves, err := movement.MoveGroup( + movement.PositionAfter{Directly: true, Pivot: "D"}, + entries, existing, + ) + + Expect(err).ToNot(HaveOccurred()) + Expect(moves).To(HaveLen(2)) + + Expect(moves[0].Movable.EntryName()).To(Equal("B")) + Expect(moves[0].Where).To(Equal(movement.ActionWhereAfter)) + Expect(moves[0].Destination.EntryName()).To(Equal("D")) + + Expect(moves[1].Movable.EntryName()).To(Equal("A")) + Expect(moves[1].Where).To(Equal(movement.ActionWhereAfter)) + Expect(moves[1].Destination.EntryName()).To(Equal("B")) + }) + }) + }) + + // '(A E B C D) -> '(A B C D E) => '(E 'bottom nil) / '(E 'after D) + + // PositionSomewhereBefore PositionDirectlyBefore + // '(C B 'before E, directly) + // '(A B C D E) -> '(A D C B E) -> '(B 'before E) + // '(A B C D E) -> '(A C B D E) -> '(B 'after C) + + Context("With PositionBefore used as position", func() { + existing := asMovable([]string{"A", "B", "C", "D", "E"}) + Context("when doing a direct move with entries reordering", func() { + It("should put reordered entries directly before pivot point", func() { + // '(A B C D E) -> '(A D C B E) + entries := asMovable([]string{"C", "B"}) + moves, err := movement.MoveGroup( + movement.PositionBefore{Directly: true, Pivot: "E"}, + entries, existing, + ) + + Expect(err).ToNot(HaveOccurred()) + Expect(moves).To(HaveLen(2)) + + Expect(moves[0].Movable.EntryName()).To(Equal("C")) + Expect(moves[0].Where).To(Equal(movement.ActionWhereBefore)) + Expect(moves[0].Destination.EntryName()).To(Equal("E")) + + Expect(moves[1].Movable.EntryName()).To(Equal("B")) + Expect(moves[1].Where).To(Equal(movement.ActionWhereAfter)) + Expect(moves[1].Destination.EntryName()).To(Equal("C")) + }) + }) + Context("when doing a non direct move with entries reordering", func() { + It("should reorder entries in-place without moving them around", func() { + // '(A B C D E) -> '(A C B D E) + entries := asMovable([]string{"C", "B"}) + moves, err := movement.MoveGroup( + movement.PositionBefore{Directly: false, Pivot: "E"}, + entries, existing, + ) + + Expect(err).ToNot(HaveOccurred()) + Expect(moves).To(HaveLen(1)) + + Expect(moves[0].Movable.EntryName()).To(Equal("C")) + Expect(moves[0].Where).To(Equal(movement.ActionWhereBefore)) + Expect(moves[0].Destination.EntryName()).To(Equal("B")) + }) + }) + Context("when direct position relative to the pivot is not required", func() { + Context("and moved entries are already before pivot point", func() { + It("should not generate any move actions", func() { + // '(A B C D E) -> '(A B C D E) + entries := asMovable([]string{"A", "B"}) + moves, err := movement.MoveGroup( + movement.PositionBefore{Directly: false, Pivot: "D"}, + entries, existing, + ) + + Expect(err).ToNot(HaveOccurred()) + Expect(moves).To(HaveLen(0)) + }) + }) + Context("and moved entries are out of order", func() { + It("should generate a single command to move B before D", func() { + // '(A B C D E) -> '(A C B D E) + entries := asMovable([]string{"C", "B"}) + moves, err := movement.MoveGroup( + movement.PositionBefore{Directly: false, Pivot: "D"}, + entries, existing, + ) + + Expect(err).ToNot(HaveOccurred()) + Expect(moves).To(HaveLen(1)) + + Expect(moves[0].Movable.EntryName()).To(Equal("C")) + Expect(moves[0].Where).To(Equal(movement.ActionWhereBefore)) + Expect(moves[0].Destination.EntryName()).To(Equal("B")) + }) + }) + Context("and moved entries are out of order", func() { + It("should generate a single command to move B before D", func() { + // '(A B C D E) -> '(A B C D E) + entries := asMovable([]string{"A", "C"}) + moves, err := movement.MoveGroup( + movement.PositionBefore{Directly: false, Pivot: "D"}, + entries, existing, + ) + + Expect(err).ToNot(HaveOccurred()) + Expect(moves).To(HaveLen(0)) + }) + }) + Context("and moved entries are out of order", func() { + It("should generate a single command to move B before D", func() { + // '(A B C D E) -> '(A C B D E) + entries := asMovable([]string{"A", "C", "B"}) + moves, err := movement.MoveGroup( + movement.PositionBefore{Directly: false, Pivot: "D"}, + entries, existing, + ) + + Expect(err).ToNot(HaveOccurred()) + Expect(moves).To(HaveLen(1)) + + Expect(moves[0].Movable.EntryName()).To(Equal("C")) + Expect(moves[0].Where).To(Equal(movement.ActionWhereAfter)) + Expect(moves[0].Destination.EntryName()).To(Equal("A")) + }) + }) + }) + Context("when direct position relative to the pivot is required", func() { + It("should generate required move actions", func() { + // '(A B C D E) -> '(C A B D E) + entries := asMovable([]string{"A", "B"}) + moves, err := movement.MoveGroup( + movement.PositionBefore{Directly: true, Pivot: "D"}, + entries, existing, + ) + + Expect(err).ToNot(HaveOccurred()) + Expect(moves).To(HaveLen(2)) + + Expect(moves[0].Movable.EntryName()).To(Equal("A")) + Expect(moves[0].Where).To(Equal(movement.ActionWhereBefore)) + Expect(moves[0].Destination.EntryName()).To(Equal("D")) + + Expect(moves[1].Movable.EntryName()).To(Equal("B")) + Expect(moves[1].Where).To(Equal(movement.ActionWhereAfter)) + Expect(moves[1].Destination.EntryName()).To(Equal("A")) + }) + }) + Context("when passing single Movement to MoveGroups()", func() { + existing := asMovable([]string{"A", "B", "C", "D", "E"}) + It("should return a set of move actions that describe it", func() { + // '(A B C D E) -> '(A D B C E) + entries := asMovable([]string{"B", "C"}) + moves, err := movement.MoveGroup( + movement.PositionBefore{Directly: true, Pivot: "E"}, + entries, existing) + + Expect(err).ToNot(HaveOccurred()) + Expect(moves).To(HaveLen(2)) + }) + }) + }) +}) + +var _ = Describe("MoveGroups()", Label("MoveGroups"), func() { + existing := asMovable([]string{"A", "B", "C", "D", "E"}) + Context("when passing single Movement to MoveGroups()", func() { + It("should return a set of move actions that describe it", func() { + // '(A B C D E) -> '(A D B C E) + entries := asMovable([]string{"B", "C"}) + movements := []movement.Movement{{ + Entries: entries, + Position: movement.PositionBefore{ + Directly: true, + Pivot: "E", + }}} + moves, err := movement.MoveGroups[Mock](existing, movements) + + Expect(err).ToNot(HaveOccurred()) + Expect(moves).To(HaveLen(2)) + }) + }) + // Context("when passing single Movement to MoveGroups()", func() { + // FIt("should return a set of move actions that describe it", func() { + // // '(A B C D E) -> '(A D B C E) -> '(D B C E A) + // movements := []movement.Movement{ + // { + // Entries: asMovable([]string{"B", "C"}), + // Position: movement.PositionBefore{ + // Directly: true, + // Pivot: Mock{"E"}}, + // }, + // { + // Entries: asMovable([]string{"A"}), + // Position: movement.PositionLast{}, + // }, + // } + // moves, err := movement.MoveGroups(existing, movements) + + // Expect(err).ToNot(HaveOccurred()) + // Expect(moves).To(HaveLen(3)) + // }) + // }) +}) + +var _ = Describe("Movement benchmarks", func() { + BeforeEach(func() { + if !Label("benchmark").MatchesLabelFilter(GinkgoLabelFilter()) { + Skip("unless label 'benchmark' is specified.") + } + }) + Context("when moving only a few elements", func() { + It("should generate a simple sequence of actions", Label("benchmark"), func() { + var elts []string + elements := 50000 + for idx := range elements { + elts = append(elts, fmt.Sprintf("%d", idx)) + } + existing := asMovable(elts) + + entries := asMovable([]string{"90", "80", "70", "60", "50", "40"}) + moves, err := movement.MoveGroup( + movement.PositionBefore{Directly: true, Pivot: "100"}, + entries, existing, + ) + + Expect(err).ToNot(HaveOccurred()) + Expect(moves).To(HaveLen(6)) + + Expect(moves[0].Movable.EntryName()).To(Equal("90")) + Expect(moves[0].Where).To(Equal(movement.ActionWhereBefore)) + Expect(moves[0].Destination.EntryName()).To(Equal("100")) + }) + }) +}) diff --git a/assets/terraform/internal/manager/manager.go b/assets/terraform/internal/manager/manager.go index 4e18e952..e598ec1b 100644 --- a/assets/terraform/internal/manager/manager.go +++ b/assets/terraform/internal/manager/manager.go @@ -30,6 +30,7 @@ func (o *Error) Unwrap() error { } var ( + ErrPlanConflict = errors.New("multiple plan entries with shared name") ErrConflict = errors.New("entry from the plan already exists on the server") ErrMissingUuid = errors.New("entry is missing required uuid") ErrMarshaling = errors.New("failed to marshal entry to XML document") diff --git a/assets/terraform/internal/manager/uuid.go b/assets/terraform/internal/manager/uuid.go index 630cd5dc..2119f163 100644 --- a/assets/terraform/internal/manager/uuid.go +++ b/assets/terraform/internal/manager/uuid.go @@ -8,7 +8,7 @@ import ( "github.com/hashicorp/terraform-plugin-framework/types" sdkerrors "github.com/PaloAltoNetworks/pango/errors" - "github.com/PaloAltoNetworks/pango/rule" + "github.com/PaloAltoNetworks/pango/movement" "github.com/PaloAltoNetworks/pango/util" "github.com/PaloAltoNetworks/pango/version" "github.com/PaloAltoNetworks/pango/xmlapi" @@ -36,7 +36,7 @@ type SDKUuidService[E UuidObject, L UuidLocation] interface { Create(context.Context, L, E) (E, error) List(context.Context, L, string, string, string) ([]E, error) Delete(context.Context, L, ...string) error - MoveGroup(context.Context, L, rule.Position, []E) error + MoveGroup(context.Context, L, movement.Position, []E) error } type uuidObjectWithState[E EntryObject] struct { @@ -156,7 +156,7 @@ func (o *UuidObjectManager[E, L, S]) entriesProperlySorted(existing []E, planEnt return movementRequired, nil } -func (o *UuidObjectManager[E, L, S]) moveExhaustive(ctx context.Context, location L, entriesByName map[string]uuidObjectWithState[E], position rule.Position) error { +func (o *UuidObjectManager[E, L, S]) moveExhaustive(ctx context.Context, location L, entriesByName map[string]uuidObjectWithState[E], position movement.Position) error { existing, err := o.service.List(ctx, location, "get", "", "") if err != nil && err.Error() != "Object not found" { return &Error{err: err, message: "Failed to list existing entries"} @@ -202,87 +202,27 @@ type position struct { // When moveNonExhaustive is called, the given list is not entirely managed by the Terraform resource. // In that case a care has to be taken to only execute movement on a subset of entries, those that // are under Terraform control. -func (o *UuidObjectManager[E, L, S]) moveNonExhaustive(ctx context.Context, location L, planEntries []E, planEntriesByName map[string]uuidObjectWithState[E], sdkPosition rule.Position) error { - - existing, err := o.service.List(ctx, location, "get", "", "") - if err != nil { - return fmt.Errorf("failed to list remote entries: %w", err) - } - - movementRequired, err := o.entriesProperlySorted(existing, planEntriesByName) - - // If all entries are ordered properly, check if their position matches the requested - // position. - if !movementRequired { - existingEntriesByName := o.entriesByName(existing, entryOk) - p, err := parseSDKPosition(sdkPosition) - if err != nil { - return ErrInvalidPosition - } - - switch p.Where { - case PositionWhereFirst: - planEntryName := planEntries[0].EntryName() - movementRequired = existing[0].EntryName() != planEntryName - case PositionWhereLast: - planEntryName := planEntries[len(planEntries)-1].EntryName() - movementRequired = existing[len(existing)-1].EntryName() != planEntryName - case PositionWhereBefore: - lastPlanElementName := planEntries[len(planEntries)-1].EntryName() - if existingPivot, found := existingEntriesByName[p.PivotEntry]; !found { - return ErrMissingPivotPoint - } else if p.Directly { - if existingPivot.StateIdx == 0 { - movementRequired = true - } else if existing[existingPivot.StateIdx-1].EntryName() != lastPlanElementName { - movementRequired = true - } - } else { - if lastPlanElementInExisting, found := existingEntriesByName[lastPlanElementName]; !found { - return ErrMissingPivotPoint - } else if lastPlanElementInExisting.StateIdx >= existingPivot.StateIdx { - movementRequired = true - } - } - case PositionWhereAfter: - firstPlanElementName := planEntries[0].EntryName() - if existingPivot, found := existingEntriesByName[p.PivotEntry]; !found { - return ErrMissingPivotPoint - } else if p.Directly { - if existingPivot.StateIdx == len(existing)-1 { - movementRequired = true - } else if existing[existingPivot.StateIdx+1].EntryName() != firstPlanElementName { - movementRequired = true - } - } else { - if firstPlanElementInExisting, found := existingEntriesByName[firstPlanElementName]; !found { - return ErrMissingPivotPoint - } else if firstPlanElementInExisting.StateIdx <= existingPivot.StateIdx { - movementRequired = true - } - } - } +func (o *UuidObjectManager[E, L, S]) moveNonExhaustive(ctx context.Context, location L, planEntries []E, planEntriesByName map[string]uuidObjectWithState[E], sdkPosition movement.Position) error { + entries := make([]E, len(planEntriesByName)) + for _, elt := range planEntriesByName { + entries[elt.StateIdx] = elt.Entry } - if movementRequired { - entries := make([]E, len(planEntriesByName)) - for _, elt := range planEntriesByName { - entries[elt.StateIdx] = elt.Entry - } - - err = o.service.MoveGroup(ctx, location, sdkPosition, entries) - if err != nil { - return &Error{err: err, message: "Failed to move group of entries"} - } + err := o.service.MoveGroup(ctx, location, sdkPosition, entries) + if err != nil { + return &Error{err: err, message: "Failed to move group of entries"} } return nil } -func (o *UuidObjectManager[E, L, S]) CreateMany(ctx context.Context, location L, planEntries []E, exhaustive ExhaustiveType, sdkPosition rule.Position) ([]E, error) { +func (o *UuidObjectManager[E, L, S]) CreateMany(ctx context.Context, location L, planEntries []E, exhaustive ExhaustiveType, sdkPosition movement.Position) ([]E, error) { var diags diag.Diagnostics planEntriesByName := o.entriesByName(planEntries, entryUnknown) + if len(planEntriesByName) != len(planEntries) { + return nil, ErrPlanConflict + } existing, err := o.service.List(ctx, location, "get", "", "") if err != nil && !sdkerrors.IsObjectNotFound(err) { @@ -367,9 +307,12 @@ func (o *UuidObjectManager[E, L, S]) CreateMany(ctx context.Context, location L, return entries, nil } -func (o *UuidObjectManager[E, L, S]) UpdateMany(ctx context.Context, location L, stateEntries []E, planEntries []E, exhaustive ExhaustiveType, position rule.Position) ([]E, error) { +func (o *UuidObjectManager[E, L, S]) UpdateMany(ctx context.Context, location L, stateEntries []E, planEntries []E, exhaustive ExhaustiveType, position movement.Position) ([]E, error) { stateEntriesByName := o.entriesByName(stateEntries, entryUnknown) planEntriesByName := o.entriesByName(planEntries, entryUnknown) + if len(planEntriesByName) != len(planEntries) { + return nil, ErrPlanConflict + } findMatchingStateEntry := func(entry E) (E, bool) { var found bool @@ -685,45 +628,3 @@ func (o *UuidObjectManager[E, L, S]) Delete(ctx context.Context, location L, ent } return nil } - -func parseSDKPosition(sdkPosition rule.Position) (position, error) { - if sdkPosition.IsValid(false) != nil { - return position{}, ErrInvalidPosition - } - - if sdkPosition.DirectlyAfter != nil { - return position{ - Directly: true, - Where: PositionWhereAfter, - PivotEntry: *sdkPosition.DirectlyAfter, - }, nil - } else if sdkPosition.DirectlyBefore != nil { - return position{ - Directly: true, - Where: PositionWhereBefore, - PivotEntry: *sdkPosition.DirectlyBefore, - }, nil - } else if sdkPosition.SomewhereAfter != nil { - return position{ - Directly: false, - Where: PositionWhereAfter, - PivotEntry: *sdkPosition.SomewhereAfter, - }, nil - } else if sdkPosition.SomewhereBefore != nil { - return position{ - Directly: false, - Where: PositionWhereBefore, - PivotEntry: *sdkPosition.SomewhereBefore, - }, nil - } else if sdkPosition.First != nil { - return position{ - Where: PositionWhereFirst, - }, nil - } else if sdkPosition.Last != nil { - return position{ - Where: PositionWhereLast, - }, nil - } - - return position{}, ErrInvalidPosition -} diff --git a/assets/terraform/internal/manager/uuid_test.go b/assets/terraform/internal/manager/uuid_test.go index 7e3d07f5..7fbce77c 100644 --- a/assets/terraform/internal/manager/uuid_test.go +++ b/assets/terraform/internal/manager/uuid_test.go @@ -2,15 +2,17 @@ package manager_test import ( "context" + "log" "log/slog" . "github.com/onsi/ginkgo/v2" . "github.com/onsi/gomega" - "github.com/PaloAltoNetworks/pango/rule" + "github.com/PaloAltoNetworks/pango/movement" sdkmanager "github.com/PaloAltoNetworks/terraform-provider-panos/internal/manager" ) +var _ = log.Printf var _ = Expect var _ = slog.Debug @@ -32,18 +34,16 @@ var _ = Describe("Server", func() { var client *MockUuidClient[*MockUuidObject] var service sdkmanager.SDKUuidService[*MockUuidObject, MockLocation] var mockService *MockUuidService[*MockUuidObject, MockLocation] - var trueVal bool var location MockLocation var ctx context.Context - var position rule.Position + var position movement.Position var entries []*MockUuidObject var mode sdkmanager.ExhaustiveType BeforeEach(func() { location = MockLocation{} ctx = context.Background() - trueVal = true initial = []*MockUuidObject{{Name: "1", Value: "A"}, {Name: "2", Value: "B"}, {Name: "3", Value: "C"}} client = NewMockUuidClient(initial) service = NewMockUuidService[*MockUuidObject, MockLocation](client) @@ -65,7 +65,7 @@ var _ = Describe("Server", func() { It("CreateMany() should create new entries on the server, and return them with uuid set", func() { entries := []*MockUuidObject{{Name: "1", Value: "A"}} - processed, err := manager.CreateMany(ctx, location, entries, sdkmanager.Exhaustive, rule.Position{First: &trueVal}) + processed, err := manager.CreateMany(ctx, location, entries, sdkmanager.Exhaustive, movement.PositionFirst{}) Expect(err).ToNot(HaveOccurred()) Expect(processed).To(HaveLen(1)) @@ -100,7 +100,7 @@ var _ = Describe("Server", func() { Context("and all entries being created are new to the server", func() { It("should create those entries in the correct position", func() { - processed, err := manager.CreateMany(ctx, location, entries, sdkmanager.NonExhaustive, rule.Position{First: &trueVal}) + processed, err := manager.CreateMany(ctx, location, entries, sdkmanager.NonExhaustive, movement.PositionFirst{}) Expect(err).ToNot(HaveOccurred()) Expect(processed).To(HaveLen(2)) @@ -117,7 +117,7 @@ var _ = Describe("Server", func() { BeforeEach(func() { entries = []*MockUuidObject{{Name: "1", Value: "A'"}, {Name: "3", Value: "C"}} mode = sdkmanager.Exhaustive - position = rule.Position{First: &trueVal} + position = movement.PositionFirst{} }) It("should not return any error and overwrite all entries on the server", func() { @@ -169,7 +169,7 @@ var _ = Describe("Server", func() { Expect(processed).To(HaveLen(3)) Expect(processed).NotTo(MatchEntries(entries)) - processed, err = manager.UpdateMany(ctx, location, entries, entries, sdkmanager.NonExhaustive, rule.Position{First: &trueVal}) + processed, err = manager.UpdateMany(ctx, location, entries, entries, sdkmanager.NonExhaustive, movement.PositionFirst{}) Expect(err).ToNot(HaveOccurred()) Expect(processed).To(HaveLen(3)) @@ -180,11 +180,34 @@ var _ = Describe("Server", func() { Context("initially has some entries", func() { Context("when creating new entries with NonExhaustive type", func() { - Context("and position is set to Last", func() { + Context("and position is set to first", func() { + It("should create new entries on the top of the list", func() { + entries := []*MockUuidObject{{Name: "4", Value: "D"}, {Name: "5", Value: "E"}, {Name: "6", Value: "F"}} + + processed, err := manager.CreateMany(ctx, location, entries, sdkmanager.NonExhaustive, movement.PositionFirst{}) + Expect(err).ToNot(HaveOccurred()) + Expect(processed).To(HaveLen(3)) + + Expect(processed[0]).To(Equal(entries[0])) + Expect(processed[1]).To(Equal(entries[1])) + Expect(processed[2]).To(Equal(entries[2])) + + clientEntries := client.list() + Expect(clientEntries).To(HaveLen(6)) + + Expect(mockService.moveGroupEntries).To(Equal(entries)) + + Expect(clientEntries[0]).To(Equal(entries[0])) + Expect(clientEntries[1]).To(Equal(entries[1])) + Expect(clientEntries[2]).To(Equal(entries[2])) + + }) + }) + Context("and position is set to last", func() { It("should create new entries on the bottom of the list", func() { entries := []*MockUuidObject{{Name: "4", Value: "D"}, {Name: "5", Value: "E"}, {Name: "6", Value: "F"}} - processed, err := manager.CreateMany(ctx, location, entries, sdkmanager.NonExhaustive, rule.Position{Last: &trueVal}) + processed, err := manager.CreateMany(ctx, location, entries, sdkmanager.NonExhaustive, movement.PositionLast{}) Expect(err).ToNot(HaveOccurred()) Expect(processed).To(HaveLen(3)) @@ -195,18 +218,19 @@ var _ = Describe("Server", func() { clientEntries := client.list() Expect(clientEntries).To(HaveLen(6)) + Expect(mockService.moveGroupEntries).To(Equal(entries)) + Expect(clientEntries[3]).To(Equal(entries[0])) Expect(clientEntries[4]).To(Equal(entries[1])) Expect(clientEntries[5]).To(Equal(entries[2])) - Expect(mockService.moveGroupEntries).To(Equal(entries)) }) }) Context("and position is set to directly after first element", func() { It("should create new entries directly after first existing element", func() { entries := []*MockUuidObject{{Name: "4", Value: "D"}, {Name: "5", Value: "E"}, {Name: "6", Value: "F"}} - processed, err := manager.CreateMany(ctx, location, entries, sdkmanager.NonExhaustive, rule.Position{DirectlyAfter: &initial[0].Name}) + processed, err := manager.CreateMany(ctx, location, entries, sdkmanager.NonExhaustive, movement.PositionAfter{Directly: true, Pivot: initial[0].Name}) Expect(err).ToNot(HaveOccurred()) Expect(processed).To(HaveLen(3)) @@ -234,7 +258,7 @@ var _ = Describe("Server", func() { entries := []*MockUuidObject{{Name: "4", Value: "D"}, {Name: "5", Value: "E"}, {Name: "6", Value: "F"}} pivot := initial[2].Name // "3" - position = rule.Position{DirectlyBefore: &pivot} + position = movement.PositionBefore{Directly: true, Pivot: pivot} processed, err := manager.CreateMany(ctx, location, entries, sdkmanager.NonExhaustive, position) Expect(err).ToNot(HaveOccurred()) @@ -251,6 +275,14 @@ var _ = Describe("Server", func() { Expect(mockService.moveGroupEntries).To(Equal(entries)) }) }) + Context("and there is a duplicate entry within a list", func() { + It("should properly raise an error", func() { + entries := []*MockUuidObject{{Name: "4", Value: "D"}, {Name: "4", Value: "D"}} + _, err := manager.CreateMany(ctx, location, entries, sdkmanager.NonExhaustive, movement.PositionFirst{}) + + Expect(err).To(MatchError(sdkmanager.ErrPlanConflict)) + }) + }) }) }) }) diff --git a/assets/terraform/internal/manager/uuid_utils_test.go b/assets/terraform/internal/manager/uuid_utils_test.go index de51b87a..df393bef 100644 --- a/assets/terraform/internal/manager/uuid_utils_test.go +++ b/assets/terraform/internal/manager/uuid_utils_test.go @@ -8,7 +8,7 @@ import ( "net/http" "net/url" - "github.com/PaloAltoNetworks/pango/rule" + "github.com/PaloAltoNetworks/pango/movement" "github.com/PaloAltoNetworks/pango/version" "github.com/PaloAltoNetworks/pango/xmlapi" @@ -180,7 +180,7 @@ func (o *MockUuidService[E, L]) removeEntriesFromCurrent(entries []*MockUuidObje return firstIdx } -func (o *MockUuidService[E, T]) MoveGroup(ctx context.Context, location MockLocation, position rule.Position, entries []*MockUuidObject) error { +func (o *MockUuidService[E, T]) MoveGroup(ctx context.Context, location MockLocation, position movement.Position, entries []*MockUuidObject) error { o.moveGroupEntries = entries firstIdx := o.removeEntriesFromCurrent(entries) @@ -190,34 +190,30 @@ func (o *MockUuidService[E, T]) MoveGroup(ctx context.Context, location MockLoca entriesList.PushBack(elt) } - if position.First != nil { + switch position.(type) { + case movement.PositionFirst: o.client.Current.PushFrontList(entriesList) return nil - } else if position.Last != nil { + case movement.PositionLast: o.client.Current.PushBackList(entriesList) return nil + case movement.PositionBefore, movement.PositionAfter: } var pivotEntry string var after bool var directly bool - if position.DirectlyBefore != nil { - pivotEntry = *position.DirectlyBefore + switch typed := position.(type) { + case movement.PositionBefore: after = false - directly = true - } else if position.DirectlyAfter != nil { - pivotEntry = *position.DirectlyAfter + directly = typed.Directly + pivotEntry = typed.Pivot + case movement.PositionAfter: after = true - directly = true - } else if position.SomewhereBefore != nil { - pivotEntry = *position.SomewhereBefore - after = false - directly = false - } else if position.SomewhereAfter != nil { - pivotEntry = *position.SomewhereAfter - after = true - directly = false + directly = typed.Directly + pivotEntry = typed.Pivot + case movement.PositionFirst, movement.PositionLast: } var pivotElt *list.Element diff --git a/assets/terraform/internal/provider/position.go b/assets/terraform/internal/provider/position.go index c1b128c1..deafc80f 100644 --- a/assets/terraform/internal/provider/position.go +++ b/assets/terraform/internal/provider/position.go @@ -8,7 +8,7 @@ import ( rsschema "github.com/hashicorp/terraform-plugin-framework/resource/schema" "github.com/hashicorp/terraform-plugin-framework/types" - "github.com/PaloAltoNetworks/pango/rule" + "github.com/PaloAltoNetworks/pango/movement" ) type TerraformPositionObject struct { @@ -34,36 +34,21 @@ func TerraformPositionObjectSchema() rsschema.SingleNestedAttribute { } } -func (o *TerraformPositionObject) CopyToPango() rule.Position { - trueVal := true +func (o *TerraformPositionObject) CopyToPango() movement.Position { switch o.Where.ValueString() { case "first": - return rule.Position{ - First: &trueVal, - } + return movement.PositionFirst{} case "last": - return rule.Position{ - Last: &trueVal, - } + return movement.PositionLast{} case "before": - if o.Directly.ValueBool() == true { - return rule.Position{ - DirectlyBefore: o.Pivot.ValueStringPointer(), - } - } else { - return rule.Position{ - SomewhereBefore: o.Pivot.ValueStringPointer(), - } + return movement.PositionBefore{ + Pivot: o.Pivot.ValueString(), + Directly: o.Directly.ValueBool(), } case "after": - if o.Directly.ValueBool() == true { - return rule.Position{ - DirectlyAfter: o.Pivot.ValueStringPointer(), - } - } else { - return rule.Position{ - SomewhereAfter: o.Pivot.ValueStringPointer(), - } + return movement.PositionAfter{ + Pivot: o.Pivot.ValueString(), + Directly: o.Directly.ValueBool(), } default: panic("unreachable") diff --git a/assets/terraform/main.go b/assets/terraform/main.go index a1540c1a..368bc3da 100644 --- a/assets/terraform/main.go +++ b/assets/terraform/main.go @@ -4,12 +4,16 @@ import ( "context" "flag" "log" + "os" + "runtime/pprof" "github.com/PaloAltoNetworks/terraform-provider-panos/internal/provider" "github.com/hashicorp/terraform-plugin-framework/providerserver" ) +var _ = pprof.StartCPUProfile + // Run "go generate" to format example terraform files and generate the docs for the registry/website // If you do not have terraform installed, you can remove the formatting command, but its suggested to @@ -35,6 +39,16 @@ func main() { flag.BoolVar(&debug, "debug", false, "set to true to run the provider with support for debuggers like delve") flag.Parse() + cpuprofile := os.Getenv("TF_PANOS_PROFILE") + if cpuprofile != "" { + f, err := os.Create(cpuprofile) + if err != nil { + log.Fatal(err) + } + pprof.StartCPUProfile(f) + defer pprof.StopCPUProfile() + } + opts := providerserver.ServeOpts{ Address: "registry.terraform.io/paloaltonetworks/panos", Debug: debug, diff --git a/assets/terraform/test/resource_nat_policy_test.go b/assets/terraform/test/resource_nat_policy_test.go index a29d61c2..b9e860b5 100644 --- a/assets/terraform/test/resource_nat_policy_test.go +++ b/assets/terraform/test/resource_nat_policy_test.go @@ -40,9 +40,12 @@ type expectServerNatRulesOrder struct { RuleNames []string } -func ExpectServerNatRulesOrder(prefix string, location nat.Location, ruleNames []string) *expectServerNatRulesOrder { +func ExpectServerNatRulesOrder(prefix string, ruleNames []string) *expectServerNatRulesOrder { + location := nat.NewDeviceGroupLocation() + location.DeviceGroup.DeviceGroup = fmt.Sprintf("%s-dg", prefix) + return &expectServerNatRulesOrder{ - Location: location, + Location: *location, Prefix: prefix, RuleNames: ruleNames, } @@ -111,10 +114,13 @@ type expectServerNatRulesCount struct { Count int } -func ExpectServerNatRulesCount(prefix string, location nat.Location, count int) *expectServerNatRulesCount { +func ExpectServerNatRulesCount(prefix string, count int) *expectServerNatRulesCount { + location := nat.NewDeviceGroupLocation() + location.DeviceGroup.DeviceGroup = fmt.Sprintf("%s-dg", prefix) + return &expectServerNatRulesCount{ Prefix: prefix, - Location: location, + Location: *location, Count: count, } } @@ -143,10 +149,23 @@ func (o *expectServerNatRulesCount) CheckState(ctx context.Context, req stateche const natPolicyExtendedResource1Tmpl = ` variable "prefix" { type = string } -variable "location" { type = map } + +resource "panos_template" "template" { + location = { panorama = {} } + + name = format("%s-tmpl", var.prefix) +} + + +resource "panos_device_group" "dg" { + location = { panorama = {} } + + name = format("%s-dg", var.prefix) + templates = [ resource.panos_template.template.name ] +} resource "panos_nat_policy" "policy" { - location = var.location + location = { device_group = { name = resource.panos_device_group.dg.name }} rules = [{ name = format("%s-rule1", var.prefix) @@ -331,19 +350,15 @@ func TestAccNatPolicyExtended(t *testing.T) { nameSuffix := acctest.RandStringFromCharSet(6, acctest.CharSetAlphaNum) prefix := fmt.Sprintf("test-acc-%s", nameSuffix) - device := devicePanorama - sdkLocation, cfgLocation := natPolicyLocationByDeviceType(device, "post-rulebase") - resource.Test(t, resource.TestCase{ PreCheck: func() { testAccPreCheck(t) }, ProtoV6ProviderFactories: testAccProviders, - CheckDestroy: natPolicyCheckDestroy(prefix, sdkLocation), + CheckDestroy: natPolicyCheckDestroy(prefix), Steps: []resource.TestStep{ { Config: natPolicyExtendedResource1Tmpl, ConfigVariables: map[string]config.Variable{ - "prefix": config.StringVariable(prefix), - "location": cfgLocation, + "prefix": config.StringVariable(prefix), }, ConfigStateChecks: []statecheck.StateCheck{ statecheck.ExpectKnownValue( @@ -433,13 +448,12 @@ func TestAccNatPolicyExtended(t *testing.T) { resource.Test(t, resource.TestCase{ PreCheck: func() { testAccPreCheck(t) }, ProtoV6ProviderFactories: testAccProviders, - CheckDestroy: natPolicyCheckDestroy(prefix, sdkLocation), + CheckDestroy: natPolicyCheckDestroy(prefix), Steps: []resource.TestStep{ { Config: natPolicyExtendedResource2Tmpl, ConfigVariables: map[string]config.Variable{ - "prefix": config.StringVariable(prefix), - "location": cfgLocation, + "prefix": config.StringVariable(prefix), }, ConfigStateChecks: []statecheck.StateCheck{ statecheck.ExpectKnownValue( @@ -517,13 +531,12 @@ func TestAccNatPolicyExtended(t *testing.T) { resource.Test(t, resource.TestCase{ PreCheck: func() { testAccPreCheck(t) }, ProtoV6ProviderFactories: testAccProviders, - CheckDestroy: natPolicyCheckDestroy(prefix, sdkLocation), + CheckDestroy: natPolicyCheckDestroy(prefix), Steps: []resource.TestStep{ { Config: natPolicyExtendedResource3Tmpl, ConfigVariables: map[string]config.Variable{ - "prefix": config.StringVariable(prefix), - "location": cfgLocation, + "prefix": config.StringVariable(prefix), }, ConfigStateChecks: []statecheck.StateCheck{ statecheck.ExpectKnownValue( @@ -565,13 +578,12 @@ func TestAccNatPolicyExtended(t *testing.T) { resource.Test(t, resource.TestCase{ PreCheck: func() { testAccPreCheck(t) }, ProtoV6ProviderFactories: testAccProviders, - CheckDestroy: natPolicyCheckDestroy(prefix, sdkLocation), + CheckDestroy: natPolicyCheckDestroy(prefix), Steps: []resource.TestStep{ { Config: natPolicyExtendedResource4Tmpl, ConfigVariables: map[string]config.Variable{ - "prefix": config.StringVariable(prefix), - "location": cfgLocation, + "prefix": config.StringVariable(prefix), }, ConfigStateChecks: []statecheck.StateCheck{ statecheck.ExpectKnownValue( @@ -645,7 +657,7 @@ func TestAccPanosNatPolicyOrdering(t *testing.T) { device := devicePanorama - sdkLocation, cfgLocation := natPolicyLocationByDeviceType(device, "pre-rulebase") + sdkLocation, _ := natPolicyLocationByDeviceType(device, "pre-rulebase") stateExpectedRuleName := func(idx int, value string) statecheck.StateCheck { return statecheck.ExpectKnownValue( @@ -670,27 +682,27 @@ func TestAccPanosNatPolicyOrdering(t *testing.T) { }, ProtoV6ProviderFactories: testAccProviders, - CheckDestroy: natPolicyCheckDestroy(prefix, sdkLocation), + CheckDestroy: natPolicyCheckDestroy(prefix), Steps: []resource.TestStep{ { Config: makeNatPolicyConfig(prefix), ConfigVariables: map[string]config.Variable{ "rule_names": config.ListVariable(withPrefix(rulesInitial)...), - "location": cfgLocation, + "prefix": config.StringVariable(prefix), }, ConfigStateChecks: []statecheck.StateCheck{ stateExpectedRuleName(0, "rule-1"), stateExpectedRuleName(1, "rule-2"), stateExpectedRuleName(2, "rule-3"), - ExpectServerNatRulesCount(prefix, sdkLocation, len(rulesInitial)), - ExpectServerNatRulesOrder(prefix, sdkLocation, rulesInitial), + ExpectServerNatRulesCount(prefix, len(rulesInitial)), + ExpectServerNatRulesOrder(prefix, rulesInitial), }, }, { Config: makeNatPolicyConfig(prefix), ConfigVariables: map[string]config.Variable{ "rule_names": config.ListVariable(withPrefix(rulesInitial)...), - "location": cfgLocation, + "prefix": config.StringVariable(prefix), }, ConfigPlanChecks: resource.ConfigPlanChecks{ PreApply: []plancheck.PlanCheck{ @@ -702,7 +714,7 @@ func TestAccPanosNatPolicyOrdering(t *testing.T) { Config: makeNatPolicyConfig(prefix), ConfigVariables: map[string]config.Variable{ "rule_names": config.ListVariable(withPrefix(rulesReordered)...), - "location": cfgLocation, + "prefix": config.StringVariable(prefix), }, ConfigPlanChecks: resource.ConfigPlanChecks{ PreApply: []plancheck.PlanCheck{ @@ -715,7 +727,7 @@ func TestAccPanosNatPolicyOrdering(t *testing.T) { stateExpectedRuleName(0, "rule-2"), stateExpectedRuleName(1, "rule-1"), stateExpectedRuleName(2, "rule-3"), - ExpectServerNatRulesOrder(prefix, sdkLocation, rulesReordered), + ExpectServerNatRulesOrder(prefix, rulesReordered), }, }, }, @@ -723,11 +735,24 @@ func TestAccPanosNatPolicyOrdering(t *testing.T) { } const configTmpl = ` +variable "prefix" { type = string } variable "rule_names" { type = list(string) } -variable "location" { type = map } + +resource "panos_template" "template" { + location = { panorama = {} } + + name = format("%s-tmpl", var.prefix) +} + +resource "panos_device_group" "dg" { + location = { panorama = {} } + + name = format("%s-dg", var.prefix) + templates = [ resource.panos_template.template.name ] +} resource "panos_nat_policy" "{{ .ResourceName }}" { - location = var.location + location = { device_group = { name = resource.panos_device_group.dg.name }} rules = [ for index, name in var.rule_names: { @@ -830,12 +855,15 @@ func natPolicyPreCheck(prefix string, location nat.Location) { } } -func natPolicyCheckDestroy(prefix string, location nat.Location) func(s *terraform.State) error { +func natPolicyCheckDestroy(prefix string) func(s *terraform.State) error { return func(s *terraform.State) error { service := nat.NewService(sdkClient) ctx := context.TODO() - rules, err := service.List(ctx, location, "get", "", "") + location := nat.NewDeviceGroupLocation() + location.DeviceGroup.DeviceGroup = fmt.Sprintf("%s-dg", prefix) + + rules, err := service.List(ctx, *location, "get", "", "") if err != nil && !sdkerrors.IsObjectNotFound(err) { return err } @@ -849,7 +877,7 @@ func natPolicyCheckDestroy(prefix string, location nat.Location) func(s *terrafo if len(danglingNames) > 0 { err := DanglingObjectsError - delErr := service.Delete(ctx, location, danglingNames...) + delErr := service.Delete(ctx, *location, danglingNames...) if delErr != nil { err = errors.Join(err, delErr) } diff --git a/assets/terraform/test/resource_security_policy_test.go b/assets/terraform/test/resource_security_policy_test.go index a070ffa2..844d7854 100644 --- a/assets/terraform/test/resource_security_policy_test.go +++ b/assets/terraform/test/resource_security_policy_test.go @@ -4,6 +4,7 @@ import ( "context" "errors" "fmt" + "regexp" "strings" "testing" @@ -26,9 +27,12 @@ type expectServerSecurityRulesOrder struct { RuleNames []string } -func ExpectServerSecurityRulesOrder(prefix string, location security.Location, ruleNames []string) *expectServerSecurityRulesOrder { +func ExpectServerSecurityRulesOrder(prefix string, ruleNames []string) *expectServerSecurityRulesOrder { + location := security.NewDeviceGroupLocation() + location.DeviceGroup.DeviceGroup = fmt.Sprintf("%s-dg", prefix) + return &expectServerSecurityRulesOrder{ - Location: location, + Location: *location, Prefix: prefix, RuleNames: ruleNames, } @@ -97,10 +101,12 @@ type expectServerSecurityRulesCount struct { Count int } -func ExpectServerSecurityRulesCount(prefix string, location security.Location, count int) *expectServerSecurityRulesCount { +func ExpectServerSecurityRulesCount(prefix string, count int) *expectServerSecurityRulesCount { + location := security.NewDeviceGroupLocation() + location.DeviceGroup.DeviceGroup = fmt.Sprintf("%s-dg", prefix) return &expectServerSecurityRulesCount{ Prefix: prefix, - Location: location, + Location: *location, Count: count, } } @@ -127,19 +133,60 @@ func (o *expectServerSecurityRulesCount) CheckState(ctx context.Context, req sta } } +const securityPolicyDuplicatedTmpl = ` +variable "prefix" { type = string } + +resource "panos_template" "template" { + location = { panorama = {} } + + name = format("%s-tmpl", var.prefix) +} + +resource "panos_device_group" "dg" { + location = { panorama = {} } + + name = format("%s-dg", var.prefix) + templates = [ resource.panos_template.template.name ] +} + + +resource "panos_security_policy" "policy" { + location = { device_group = { name = resource.panos_device_group.dg.name }} + + rules = [ + { + name = format("%s-rule", var.prefix) + source_zones = ["any"] + source_addresses = ["any"] + + destination_zones = ["any"] + destination_addresses = ["any"] + }, + { + name = format("%s-rule", var.prefix) + source_zones = ["any"] + source_addresses = ["any"] + + destination_zones = ["any"] + destination_addresses = ["any"] + } + ] +} +` + const securityPolicyExtendedResource1Tmpl = ` variable "prefix" { type = string } resource "panos_template" "template" { location = { panorama = {} } - name = format("%s-secgroup-tmpl1", var.prefix) + name = format("%s-tmpl", var.prefix) } resource "panos_device_group" "dg" { location = { panorama = {} } - name = format("%s-secgroup-dg1", var.prefix) + name = format("%s-dg", var.prefix) templates = [ resource.panos_template.template.name ] } @@ -192,6 +239,27 @@ resource "panos_security_policy" "policy" { } ` +func TestAccSecurityPolicyDuplicatedPlan(t *testing.T) { + t.Parallel() + + nameSuffix := acctest.RandStringFromCharSet(6, acctest.CharSetAlphaNum) + prefix := fmt.Sprintf("test-acc-%s", nameSuffix) + + resource.Test(t, resource.TestCase{ + PreCheck: func() { testAccPreCheck(t) }, + ProtoV6ProviderFactories: testAccProviders, + Steps: []resource.TestStep{ + { + Config: securityPolicyDuplicatedTmpl, + ConfigVariables: map[string]config.Variable{ + "prefix": config.StringVariable(prefix), + }, + ExpectError: regexp.MustCompile("List entries must have unique names"), + }, + }, + }) +} + func TestAccSecurityPolicyExtended(t *testing.T) { t.Parallel() @@ -399,11 +467,11 @@ func TestAccSecurityPolicyExtended(t *testing.T) { } const securityPolicyOrderingTmpl = ` +variable "prefix" { type = string } variable "rule_names" { type = list(string) } -variable "location" { type = map } resource "panos_security_policy" "policy" { - location = var.location + location = { device_group = { name = format("%s-dg", var.prefix) }} rules = [ for index, name in var.rule_names: { @@ -444,10 +512,6 @@ func TestAccSecurityPolicyOrdering(t *testing.T) { return result } - device := devicePanorama - - sdkLocation, cfgLocation := securityPolicyLocationByDeviceType(device, "pre-rulebase") - stateExpectedRuleName := func(idx int, value string) statecheck.StateCheck { return statecheck.ExpectKnownValue( "panos_security_policy.policy", @@ -467,24 +531,24 @@ func TestAccSecurityPolicyOrdering(t *testing.T) { resource.Test(t, resource.TestCase{ PreCheck: func() { testAccPreCheck(t) - securityPolicyPreCheck(prefix, sdkLocation) + securityPolicyPreCheck(prefix) }, ProtoV6ProviderFactories: testAccProviders, - CheckDestroy: securityPolicyCheckDestroy(prefix, sdkLocation), + CheckDestroy: securityPolicyCheckDestroy(prefix), Steps: []resource.TestStep{ { Config: securityPolicyOrderingTmpl, ConfigVariables: map[string]config.Variable{ + "prefix": config.StringVariable(prefix), "rule_names": config.ListVariable([]config.Variable{}...), - "location": cfgLocation, }, }, { Config: securityPolicyOrderingTmpl, ConfigVariables: map[string]config.Variable{ + "prefix": config.StringVariable(prefix), "rule_names": config.ListVariable([]config.Variable{}...), - "location": cfgLocation, }, PlanOnly: true, ExpectNonEmptyPlan: false, @@ -492,8 +556,8 @@ func TestAccSecurityPolicyOrdering(t *testing.T) { { Config: securityPolicyOrderingTmpl, ConfigVariables: map[string]config.Variable{ + "prefix": config.StringVariable(prefix), "rule_names": config.ListVariable(withPrefix(rulesInitial)...), - "location": cfgLocation, }, ConfigStateChecks: []statecheck.StateCheck{ stateExpectedRuleName(0, "rule-1"), @@ -501,15 +565,15 @@ func TestAccSecurityPolicyOrdering(t *testing.T) { stateExpectedRuleName(2, "rule-3"), stateExpectedRuleName(3, "rule-4"), stateExpectedRuleName(4, "rule-5"), - ExpectServerSecurityRulesCount(prefix, sdkLocation, len(rulesInitial)), - ExpectServerSecurityRulesOrder(prefix, sdkLocation, rulesInitial), + ExpectServerSecurityRulesCount(prefix, len(rulesInitial)), + ExpectServerSecurityRulesOrder(prefix, rulesInitial), }, }, { Config: securityPolicyOrderingTmpl, ConfigVariables: map[string]config.Variable{ + "prefix": config.StringVariable(prefix), "rule_names": config.ListVariable(withPrefix(rulesInitial)...), - "location": cfgLocation, }, ConfigPlanChecks: resource.ConfigPlanChecks{ PreApply: []plancheck.PlanCheck{ @@ -520,8 +584,8 @@ func TestAccSecurityPolicyOrdering(t *testing.T) { { Config: securityPolicyOrderingTmpl, ConfigVariables: map[string]config.Variable{ + "prefix": config.StringVariable(prefix), "rule_names": config.ListVariable(withPrefix(rulesReordered)...), - "location": cfgLocation, }, ConfigPlanChecks: resource.ConfigPlanChecks{ PreApply: []plancheck.PlanCheck{ @@ -538,21 +602,21 @@ func TestAccSecurityPolicyOrdering(t *testing.T) { stateExpectedRuleName(2, "rule-3"), stateExpectedRuleName(3, "rule-4"), stateExpectedRuleName(4, "rule-5"), - ExpectServerSecurityRulesOrder(prefix, sdkLocation, rulesReordered), + ExpectServerSecurityRulesOrder(prefix, rulesReordered), }, }, { Config: securityPolicyOrderingTmpl, ConfigVariables: map[string]config.Variable{ + "prefix": config.StringVariable(prefix), "rule_names": config.ListVariable([]config.Variable{}...), - "location": cfgLocation, }, }, { Config: securityPolicyOrderingTmpl, ConfigVariables: map[string]config.Variable{ + "prefix": config.StringVariable(prefix), "rule_names": config.ListVariable([]config.Variable{}...), - "location": cfgLocation, }, PlanOnly: true, ExpectNonEmptyPlan: false, @@ -561,39 +625,7 @@ func TestAccSecurityPolicyOrdering(t *testing.T) { }) } -func securityPolicyLocationByDeviceType(typ deviceType, rulebase string) (security.Location, config.Variable) { - var sdkLocation security.Location - var cfgLocation config.Variable - switch typ { - case devicePanorama: - sdkLocation = security.Location{ - Shared: &security.SharedLocation{ - Rulebase: rulebase, - }, - } - cfgLocation = config.ObjectVariable(map[string]config.Variable{ - "shared": config.ObjectVariable(map[string]config.Variable{ - "rulebase": config.StringVariable(rulebase), - }), - }) - case deviceFirewall: - sdkLocation = security.Location{ - Vsys: &security.VsysLocation{ - NgfwDevice: "localhost.localdomain", - Vsys: "vsys1", - }, - } - cfgLocation = config.ObjectVariable(map[string]config.Variable{ - "vsys": config.ObjectVariable(map[string]config.Variable{ - "name": config.StringVariable("vsys1"), - }), - }) - } - - return sdkLocation, cfgLocation -} - -func securityPolicyPreCheck(prefix string, location security.Location) { +func securityPolicyPreCheck(prefix string) { service := security.NewService(sdkClient) ctx := context.TODO() @@ -620,8 +652,11 @@ func securityPolicyPreCheck(prefix string, location security.Location) { }, } + location := security.NewDeviceGroupLocation() + location.DeviceGroup.DeviceGroup = fmt.Sprintf("%s-dg", prefix) + for _, elt := range rules { - _, err := service.Create(ctx, location, &elt) + _, err := service.Create(ctx, *location, &elt) if err != nil { panic(fmt.Sprintf("natPolicyPreCheck failed: %s", err)) } @@ -629,12 +664,15 @@ func securityPolicyPreCheck(prefix string, location security.Location) { } } -func securityPolicyCheckDestroy(prefix string, location security.Location) func(s *terraform.State) error { +func securityPolicyCheckDestroy(prefix string) func(s *terraform.State) error { return func(s *terraform.State) error { service := security.NewService(sdkClient) ctx := context.TODO() - rules, err := service.List(ctx, location, "get", "", "") + location := security.NewDeviceGroupLocation() + location.DeviceGroup.DeviceGroup = fmt.Sprintf("%s-dg", prefix) + + rules, err := service.List(ctx, *location, "get", "", "") if err != nil && !sdkerrors.IsObjectNotFound(err) { return err } @@ -648,7 +686,7 @@ func securityPolicyCheckDestroy(prefix string, location security.Location) func( if len(danglingNames) > 0 { err := DanglingObjectsError - delErr := service.Delete(ctx, location, danglingNames...) + delErr := service.Delete(ctx, *location, danglingNames...) if delErr != nil { err = errors.Join(err, delErr) } @@ -659,47 +697,3 @@ func securityPolicyCheckDestroy(prefix string, location security.Location) func( return nil } } - -func init() { - resource.AddTestSweepers("pango_security_policy", &resource.Sweeper{ - Name: "pango_security_policy", - F: func(typ string) error { - service := security.NewService(sdkClient) - - var deviceTyp deviceType - switch typ { - case "panorama": - deviceTyp = devicePanorama - case "firewall": - deviceTyp = deviceFirewall - default: - panic("invalid device type") - } - - for _, rulebase := range []string{"pre-rulebase", "post-rulebase"} { - location, _ := securityPolicyLocationByDeviceType(deviceTyp, rulebase) - ctx := context.TODO() - objects, err := service.List(ctx, location, "get", "", "") - if err != nil && !sdkerrors.IsObjectNotFound(err) { - return fmt.Errorf("Failed to list Security Rules during sweep: %w", err) - } - - var names []string - for _, elt := range objects { - if strings.HasPrefix(elt.Name, "test-acc") { - names = append(names, elt.Name) - } - } - - if len(names) > 0 { - err = service.Delete(ctx, location, names...) - if err != nil { - return fmt.Errorf("Failed to delete Security Rules during sweep: %w", err) - } - } - } - - return nil - }, - }) -} diff --git a/pkg/translate/imports.go b/pkg/translate/imports.go index f4fa31ec..f6932929 100644 --- a/pkg/translate/imports.go +++ b/pkg/translate/imports.go @@ -45,6 +45,8 @@ func RenderImports(templateTypes ...string) (string, error) { manager.AddSdkImport("github.com/PaloAltoNetworks/pango/audit", "") case "rule": manager.AddSdkImport("github.com/PaloAltoNetworks/pango/rule", "") + case "movement": + manager.AddSdkImport("github.com/PaloAltoNetworks/pango/movement", "") case "version": manager.AddSdkImport("github.com/PaloAltoNetworks/pango/version", "") case "template": diff --git a/pkg/translate/terraform_provider/template.go b/pkg/translate/terraform_provider/template.go index 566041eb..ad8db6e6 100644 --- a/pkg/translate/terraform_provider/template.go +++ b/pkg/translate/terraform_provider/template.go @@ -275,13 +275,37 @@ func (r *{{ resourceStructName }}) Metadata(ctx context.Context, req resource.Me func (r *{{ resourceStructName }}) ValidateConfig(ctx context.Context, req resource.ValidateConfigRequest, resp *resource.ValidateConfigResponse) { {{- if HasPosition }} + { var resource {{ resourceStructName }}Model resp.Diagnostics.Append(req.Config.Get(ctx, &resource)...) if resp.Diagnostics.HasError() { return } - resource.Position.ValidateConfig(resp) + } +{{- end }} + +{{- if IsUuid }} + { + var resource {{ resourceStructName }}Model + resp.Diagnostics.Append(req.Config.Get(ctx, &resource)...) + if resp.Diagnostics.HasError() { + return + } + {{ $resourceTFStructName := printf "%s%sObject" resourceStructName ListAttribute.CamelCase }} + entries := make(map[string]struct{}) + var elements []{{ $resourceTFStructName }} + resource.{{ ListAttribute.CamelCase }}.ElementsAs(ctx, &elements, false) + + for _, elt := range elements { + entry := elt.Name.ValueString() + if _, found := entries[entry]; found { + resp.Diagnostics.AddError("Failed to validate resource", "List entries must have unique names") + return + } + entries[entry] = struct{}{} + } + } {{- end }} } @@ -494,8 +518,7 @@ if err != nil { return } {{- else if .Exhaustive }} -trueVal := true -processed, err := r.manager.CreateMany(ctx, location, entries, sdkmanager.Exhaustive, rule.Position{First: &trueVal}) +processed, err := r.manager.CreateMany(ctx, location, entries, sdkmanager.Exhaustive, movement.PositionFirst{}) if err != nil { resp.Diagnostics.AddError("Error during CreateMany() call", err.Error()) return @@ -1041,8 +1064,7 @@ for idx, elt := range elements { {{ $exhaustive := "sdkmanager.NonExhaustive" }} {{- if .Exhaustive }} {{ $exhaustive = "sdkmanager.Exhaustive" }} -trueValue := true -position := rule.Position{First: &trueValue} +position := movement.PositionFirst{} {{- else }} position := state.Position.CopyToPango() {{- end }} diff --git a/pkg/translate/terraform_provider/terraform_provider_file.go b/pkg/translate/terraform_provider/terraform_provider_file.go index dcbbe8d0..372c1e98 100644 --- a/pkg/translate/terraform_provider/terraform_provider_file.go +++ b/pkg/translate/terraform_provider/terraform_provider_file.go @@ -127,13 +127,16 @@ func (g *GenerateTerraformProvider) GenerateTerraformResource(resourceTyp proper } funcMap := template.FuncMap{ - "GoSDKSkipped": func() bool { return spec.GoSdkSkip }, - "IsEntry": func() bool { return spec.HasEntryName() && !spec.HasEntryUuid() }, - "HasImports": func() bool { return len(spec.Imports) > 0 }, - "IsCustom": func() bool { return spec.TerraformProviderConfig.ResourceType == properties.TerraformResourceCustom }, - "IsUuid": func() bool { return spec.HasEntryUuid() }, - "IsConfig": func() bool { return !spec.HasEntryName() && !spec.HasEntryUuid() }, - "IsImportable": func() bool { return resourceTyp == properties.ResourceEntry }, + "GoSDKSkipped": func() bool { return spec.GoSdkSkip }, + "IsEntry": func() bool { return spec.HasEntryName() && !spec.HasEntryUuid() }, + "HasImports": func() bool { return len(spec.Imports) > 0 }, + "IsCustom": func() bool { return spec.TerraformProviderConfig.ResourceType == properties.TerraformResourceCustom }, + "IsUuid": func() bool { return spec.HasEntryUuid() }, + "IsConfig": func() bool { return !spec.HasEntryName() && !spec.HasEntryUuid() }, + "IsImportable": func() bool { return resourceTyp == properties.ResourceEntry }, + "ListAttribute": func() *properties.NameVariant { + return properties.NewNameVariant(spec.TerraformProviderConfig.PluralName) + }, "resourceSDKName": func() string { return names.PackageName }, "HasPosition": func() bool { return hasPosition }, "metaName": func() string { return names.MetaName }, @@ -198,7 +201,7 @@ func (g *GenerateTerraformProvider) GenerateTerraformResource(resourceTyp proper terraformProvider.ImportManager.AddStandardImport("errors", "") switch resourceTyp { case properties.ResourceUuid: - terraformProvider.ImportManager.AddSdkImport("github.com/PaloAltoNetworks/pango/rule", "") + terraformProvider.ImportManager.AddSdkImport("github.com/PaloAltoNetworks/pango/movement", "") case properties.ResourceEntry: case properties.ResourceUuidPlural: case properties.ResourceEntryPlural: diff --git a/templates/sdk/service.tmpl b/templates/sdk/service.tmpl index 9331048a..141e269f 100644 --- a/templates/sdk/service.tmpl +++ b/templates/sdk/service.tmpl @@ -2,13 +2,13 @@ package {{packageName .GoSdkPath}} {{- if .Entry}} {{- if $.Imports}} {{- if $.Spec.Params.uuid}} - {{renderImports "service" "filtering" "sync" "audit" "rule" "version"}} + {{renderImports "service" "filtering" "sync" "audit" "rule" "version" "movement"}} {{- else}} {{renderImports "service" "filtering" "sync"}} {{- end}} {{- else}} {{- if $.Spec.Params.uuid}} - {{renderImports "service" "filtering" "audit" "rule" "version"}} + {{renderImports "service" "filtering" "audit" "movement"}} {{- else}} {{renderImports "service" "filtering"}} {{- end}} @@ -803,396 +803,62 @@ func (s *Service) RemoveFromImport(ctx context.Context, loc Location, entry Entr // MoveGroup arranges the given rules in the order specified. // Any rule with a UUID specified is ignored. // Only the rule names are considered for the purposes of the rule placement. - func (s *Service) MoveGroup(ctx context.Context, loc Location, position rule.Position, entries []*Entry) error { + func (s *Service) MoveGroup(ctx context.Context, loc Location, position movement.Position, entries []*Entry) error { if len(entries) == 0 { - return nil + return nil } - listing, err := s.List(ctx, loc, "get", "", "") + existing, err := s.List(ctx, loc, "get", "", "") if err != nil { - return err - } else if len(listing) == 0 { - return fmt.Errorf("no rules present") - } - - rp := make(map[string]int) - for idx, live := range listing { - rp[live.Name] = idx + return err + } else if len(existing) == 0 { + return fmt.Errorf("no rules present") } - vn := s.client.Versioning() - updates := xmlapi.NewMultiConfig(len(entries)) - - var ok, topDown bool - var otherIndex int - baseIndex := -1 - switch { - case position.First != nil && *position.First: - topDown, baseIndex, ok, err = s.moveTop(topDown, entries, baseIndex, ok, rp, loc, vn, updates) - if err != nil { - return err - } - case position.Last != nil && *position.Last: - baseIndex, ok, err = s.moveBottom(entries, baseIndex, ok, rp, listing, loc, vn, updates) - if err != nil { - return err - } - case position.SomewhereAfter != nil && *position.SomewhereAfter != "": - topDown, baseIndex, ok, otherIndex, err = s.moveSomewhereAfter(topDown, entries, baseIndex, ok, rp, otherIndex, position, loc, vn, updates) - if err != nil { - return err - } - case position.SomewhereBefore != nil && *position.SomewhereBefore != "": - baseIndex, ok, otherIndex, err = s.moveSomewhereBefore(entries, baseIndex, ok, rp, otherIndex, position, loc, vn, updates) - if err != nil { - return err - } - case position.DirectlyAfter != nil && *position.DirectlyAfter != "": - topDown, baseIndex, ok, otherIndex, err = s.moveDirectlyAfter(topDown, entries, baseIndex, ok, rp, otherIndex, position, loc, vn, updates) - if err != nil { - return err - } - case position.DirectlyBefore != nil && *position.DirectlyBefore != "": - baseIndex, ok, err = s.moveDirectlyBefore(entries, baseIndex, ok, rp, otherIndex, position, loc, vn, updates) - if err != nil { - return err - } - default: - topDown = true - target := entries[0] - - baseIndex, ok = rp[target.Name] - if !ok { - return fmt.Errorf("could not find rule %q for first positioning", target.Name) - } - } + movements, err := movement.MoveGroup(position, entries, existing) + if err != nil { + return err + } - var prevName, where string - if topDown { - prevName = entries[0].Name - where = "after" - } else { - prevName = entries[len(entries)-1].Name - where = "before" - } + updates := xmlapi.NewMultiConfig(len(movements)) + + for _, elt := range movements { + path, err := loc.XpathWithEntryName(s.client.Versioning(), elt.Movable.EntryName()) + if err != nil { + return err + } + + switch elt.Where { + case movement.ActionWhereFirst, movement.ActionWhereLast: + updates.Add(&xmlapi.Config{ + Action: "move", + Xpath: util.AsXpath(path), + Where: string(elt.Where), + Destination: string(elt.Where), + Target: s.client.GetTarget(), + }) + case movement.ActionWhereBefore, movement.ActionWhereAfter: + updates.Add(&xmlapi.Config{ + Action: "move", + Xpath: util.AsXpath(path), + Where: string(elt.Where), + Destination: elt.Destination.EntryName(), + Target: s.client.GetTarget(), + }) + } - for i := 1; i < len(entries); i++ { - err := s.moveRestOfRules(topDown, entries, i, baseIndex, rp, loc, vn, updates, where, prevName) - if err != nil { - return err - } - } + } if len(updates.Operations) > 0 { - _, _, _, err = s.client.MultiConfig(ctx, updates, false, nil) - return err - } - - return nil - } - - func (s *Service) moveRestOfRules(topDown bool, entries []*Entry, i int, baseIndex int, rp map[string]int, loc Location, vn version.Number, updates *xmlapi.MultiConfig, where string, prevName string) error { - var target Entry - var desiredIndex int - if topDown { - target = *entries[i] - desiredIndex = baseIndex + i - } else { - target = *entries[len(entries)-1-i] - desiredIndex = baseIndex - i - } - - idx, ok := rp[target.Name] - if !ok { - return fmt.Errorf("rule %q not present", target.Name) - } - - if idx != desiredIndex { - path, err := loc.XpathWithEntryName(vn, target.Name) - if err != nil { - return err - } - - if idx < desiredIndex { - for name, val := range rp { - if val > idx && val <= desiredIndex { - rp[name] = val - 1 - } - } - } else { - for name, val := range rp { - if val < idx && val >= desiredIndex { - rp[name] = val + 1 - } - } - } - rp[target.Name] = desiredIndex - - updates.Add(&xmlapi.Config{ - Action: "move", - Xpath: util.AsXpath(path), - Where: where, - Destination: prevName, - Target: s.client.GetTarget(), - }) - } - - prevName = target.Name - return nil - } - - func (s *Service) moveDirectlyBefore(entries []*Entry, baseIndex int, ok bool, rp map[string]int, otherIndex int, position rule.Position, loc Location, vn version.Number, updates *xmlapi.MultiConfig) (int, bool, error) { - target := entries[len(entries)-1] - - baseIndex, ok = rp[target.Name] - if !ok { - return 0, false, fmt.Errorf("could not find rule %q for initial positioning", target.Name) - } - - otherIndex, ok = rp[*position.DirectlyBefore] - if !ok { - return 0, false, fmt.Errorf("could not find referenced rule %q", *position.DirectlyBefore) - } - - if baseIndex+1 != otherIndex { - path, err := loc.XpathWithEntryName(vn, target.Name) - if err != nil { - return 0, false, err - } - - for name, val := range rp { - switch { - case name == target.Name: - rp[name] = otherIndex - case val < baseIndex && val >= otherIndex: - rp[name] = val + 1 - } - } - - updates.Add(&xmlapi.Config{ - Action: "move", - Xpath: util.AsXpath(path), - Where: "before", - Destination: *position.DirectlyBefore, - Target: s.client.GetTarget(), - }) - - baseIndex = otherIndex - } - return baseIndex, ok, nil - } - - func (s *Service) moveDirectlyAfter(topDown bool, entries []*Entry, baseIndex int, ok bool, rp map[string]int, otherIndex int, position rule.Position, loc Location, vn version.Number, updates *xmlapi.MultiConfig) (bool, int, bool, int, error) { - topDown = true - target := entries[0] - - baseIndex, ok = rp[target.Name] - if !ok { - return false, 0, false, 0, fmt.Errorf("could not find rule %q for initial positioning", target.Name) - } - - otherIndex, ok = rp[*position.DirectlyAfter] - if !ok { - return false, 0, false, 0, fmt.Errorf("could not find referenced rule %q for initial positioning", *position.DirectlyAfter) + _, _, _, err = s.client.MultiConfig(ctx, updates, false, nil) + return err } - if baseIndex != otherIndex+1 { - path, err := loc.XpathWithEntryName(vn, target.Name) - if err != nil { - return false, 0, false, 0, err - } - - for name, val := range rp { - switch { - case name == target.Name: - rp[name] = otherIndex - case val > baseIndex && val <= otherIndex: - rp[name] = otherIndex - 1 - } - } - - updates.Add(&xmlapi.Config{ - Action: "move", - Xpath: util.AsXpath(path), - Where: "after", - Destination: *position.DirectlyAfter, - Target: s.client.GetTarget(), - }) - - baseIndex = otherIndex - } - return topDown, baseIndex, ok, otherIndex, nil - } - - func (s *Service) moveSomewhereBefore(entries []*Entry, baseIndex int, ok bool, rp map[string]int, otherIndex int, position rule.Position, loc Location, vn version.Number, updates *xmlapi.MultiConfig) (int, bool, int, error) { - target := entries[len(entries)-1] - - baseIndex, ok = rp[target.Name] - if !ok { - return 0, false, 0, fmt.Errorf("could not find rule %q for initial positioning", target.Name) - } - - otherIndex, ok = rp[*position.SomewhereBefore] - if !ok { - return 0, false, 0, fmt.Errorf("could not find referenced rule %q", *position.SomewhereBefore) - } - - if baseIndex > otherIndex { - path, err := loc.XpathWithEntryName(vn, target.Name) - if err != nil { - return 0, false, 0, err - } - - for name, val := range rp { - switch { - case name == target.Name: - rp[name] = otherIndex - case val < baseIndex && val >= otherIndex: - rp[name] = val + 1 - } - } - - updates.Add(&xmlapi.Config{ - Action: "move", - Xpath: util.AsXpath(path), - Where: "before", - Destination: *position.SomewhereBefore, - Target: s.client.GetTarget(), - }) - - baseIndex = otherIndex - } - return baseIndex, ok, otherIndex, nil - } - - func (s *Service) moveSomewhereAfter(topDown bool, entries []*Entry, baseIndex int, ok bool, rp map[string]int, otherIndex int, position rule.Position, loc Location, vn version.Number, updates *xmlapi.MultiConfig) (bool, int, bool, int, error) { - topDown = true - target := entries[0] - - baseIndex, ok = rp[target.Name] - if !ok { - return false, 0, false, 0, fmt.Errorf("could not find rule %q for initial positioning", target.Name) - } - - otherIndex, ok = rp[*position.SomewhereAfter] - if !ok { - return false, 0, false, 0, fmt.Errorf("could not find referenced rule %q for initial positioning", *position.SomewhereAfter) - } - - if baseIndex < otherIndex { - path, err := loc.XpathWithEntryName(vn, target.Name) - if err != nil { - return false, 0, false, 0, err - } - - for name, val := range rp { - switch { - case name == target.Name: - rp[name] = otherIndex - case val > baseIndex && val <= otherIndex: - rp[name] = otherIndex - 1 - } - } - - updates.Add(&xmlapi.Config{ - Action: "move", - Xpath: util.AsXpath(path), - Where: "after", - Destination: *position.SomewhereAfter, - Target: s.client.GetTarget(), - }) - - baseIndex = otherIndex - } - return topDown, baseIndex, ok, otherIndex, nil - } - - func (s *Service) moveBottom(entries []*Entry, baseIndex int, ok bool, rp map[string]int, listing []*Entry, loc Location, vn version.Number, updates *xmlapi.MultiConfig) (int, bool, error) { - target := entries[len(entries)-1] - - baseIndex, ok = rp[target.Name] - if !ok { - return 0, false, fmt.Errorf("could not find rule %q for last positioning", target.Name) - } - - if baseIndex != len(listing)-1 { - path, err := loc.XpathWithEntryName(vn, target.Name) - if err != nil { - return 0, false, err - } - - for name, val := range rp { - switch { - case name == target.Name: - rp[name] = len(listing) - 1 - case val > baseIndex: - rp[name] = val - 1 - } - } - - // some versions of PAN-OS require that the destination always be set - var dst string - if !vn.Gte(util.FixedPanosVersionForMultiConfigMove) { - dst = "bottom" - } - - updates.Add(&xmlapi.Config{ - Action: "move", - Xpath: util.AsXpath(path), - Where: "bottom", - Destination: dst, - Target: s.client.GetTarget(), - }) - - baseIndex = len(listing) - 1 - } - return baseIndex, ok, nil - } - - func (s *Service) moveTop(topDown bool, entries []*Entry, baseIndex int, ok bool, rp map[string]int, loc Location, vn version.Number, updates *xmlapi.MultiConfig) (bool, int, bool, error) { - topDown = true - target := entries[0] - - baseIndex, ok = rp[target.Name] - if !ok { - return false, 0, false, fmt.Errorf("could not find rule %q for first positioning", target.Name) - } - - if baseIndex != 0 { - path, err := loc.XpathWithEntryName(vn, target.Name) - if err != nil { - return false, 0, false, err - } - - for name, val := range rp { - switch { - case name == entries[0].Name: - rp[name] = 0 - case val < baseIndex: - rp[name] = val + 1 - } - } - - // some versions of PAN-OS require that the destination always be set - var dst string - if !vn.Gte(util.FixedPanosVersionForMultiConfigMove) { - dst = "top" - } - - updates.Add(&xmlapi.Config{ - Action: "move", - Xpath: util.AsXpath(path), - Where: "top", - Destination: dst, - Target: s.client.GetTarget(), - }) + return nil +} - baseIndex = 0 - } - return topDown, baseIndex, ok, nil - } - // HitCount returns the hit count for the given rule. + // HITCOUNT returns the hit count for the given rule. func (s *Service) HitCount(ctx context.Context, loc Location, rules ...string) ([]util.HitCount, error) { switch { case loc.Vsys != nil: