diff --git a/pkg/go/graph/graph.go b/pkg/go/graph/graph.go index b71d4142..03afb810 100644 --- a/pkg/go/graph/graph.go +++ b/pkg/go/graph/graph.go @@ -2,6 +2,11 @@ package graph import ( "errors" + "reflect" + "slices" + "sort" + + "github.com/openfga/language/pkg/go/utils" "gonum.org/v1/gonum/graph" "gonum.org/v1/gonum/graph/encoding" @@ -41,14 +46,59 @@ func (g *AuthorizationModelGraph) GetDOT() string { return string(dotRepresentation) } +func sortAndRemoveDuplicateAndAlgebraic(orig [][]string) [][]string { + newSlices := make(utils.SlicesOfSlices, 0, len(orig)) + for _, slice := range orig { + newSlice := make([]string, 0, len(slice)) + for _, item := range slice { + if item != union && item != intersection && item != exclusion && !slices.Contains(newSlice, item) { + newSlice = append(newSlice, item) + } + } + sort.Strings(newSlice) + if !slicesOfSliceContains(newSlices, newSlice) { + newSlices = append(newSlices, newSlice) + } + } + // now, sort the slices according to size + sort.Sort(newSlices) + + return newSlices +} + // CycleInformation encapsulates whether the graph has cycles. type CycleInformation struct { - // If hasCyclesAtCompileTime is true, we should block this model from ever being written. + // If hasCyclesAtCompileTime is non-empty, we should block this model from ever being written. // This is because we are trying to perform a Check on it will cause a stack overflow no matter what the tuples are. - hasCyclesAtCompileTime bool + hasCyclesAtCompileTime [][]string - // If canHaveCyclesAtRuntime is true, there could exist tuples that introduce a cycle. - canHaveCyclesAtRuntime bool + // If canHaveCyclesAtRuntime is non-empty, there could exist tuples that introduce a cycle. + canHaveCyclesAtRuntime [][]string +} + +// slicesOfSliceContains returns true if the newSlice is deeply equal to any of the origSlices. +func slicesOfSliceContains(origSlices [][]string, newSlice []string) bool { + for _, origSlice := range origSlices { + if reflect.DeepEqual(newSlice, origSlice) { + return true + } + } + + return false +} + +// SortedHasCyclesAtCompileTime returns a sorted HasCyclesAtCompileTime which removed algebraic operation (such as union/intersection/exclusion) +// The []string are sorted by length. If []string has the same length, it will return if the first/second/third/.. item is smallest +// Within each []string, it is sorted by alphabet. In addition, the duplicate node is removed. +func (c *CycleInformation) SortedHasCyclesAtCompileTime() [][]string { + return sortAndRemoveDuplicateAndAlgebraic(c.hasCyclesAtCompileTime) +} + +// SortedCanHaveCyclesAtRuntime returns a sorted HasCyclesAtCompileTime which removed algebraic operation (such as union/intersection/exclusion) +// The []string are sorted by length. If []string has the same length, it will return if the first/second/third/.. item is smallest +// Within each []string, it is sorted by alphabet. In addition, the duplicate node is removed. +func (c *CycleInformation) SortedCanHaveCyclesAtRuntime() [][]string { + return sortAndRemoveDuplicateAndAlgebraic(c.canHaveCyclesAtRuntime) } func (g *AuthorizationModelGraph) nodeListHasNonComputedEdge(nodeList []graph.Node) bool { @@ -67,24 +117,36 @@ func (g *AuthorizationModelGraph) nodeListHasNonComputedEdge(nodeList []graph.No return false } +func nodeListIdentifier(nodeList []graph.Node) []string { + labels := make([]string, 0, len(nodeList)) + for _, node := range nodeList { + auth, ok := node.(*AuthorizationModelNode) + if ok { + labels = append(labels, auth.label) + } + } + + return labels +} + func (g *AuthorizationModelGraph) GetCycles() CycleInformation { - hasCyclesAtCompileTime := false - hasCyclesAtRuntime := false + var nodesWithCyclesAtCompileTime [][]string + var nodesWithCyclesAtRuntime [][]string // TODO: investigate whether len(1) should be identified as cycle nodes := topo.DirectedCyclesIn(g) for _, nodeList := range nodes { if g.nodeListHasNonComputedEdge(nodeList) { - hasCyclesAtRuntime = true + nodesWithCyclesAtRuntime = append(nodesWithCyclesAtRuntime, nodeListIdentifier(nodeList)) } else { - hasCyclesAtCompileTime = true + nodesWithCyclesAtCompileTime = append(nodesWithCyclesAtCompileTime, nodeListIdentifier(nodeList)) } } return CycleInformation{ - hasCyclesAtCompileTime: hasCyclesAtCompileTime, - canHaveCyclesAtRuntime: hasCyclesAtRuntime, + hasCyclesAtCompileTime: nodesWithCyclesAtCompileTime, + canHaveCyclesAtRuntime: nodesWithCyclesAtRuntime, } } diff --git a/pkg/go/graph/graph_builder.go b/pkg/go/graph/graph_builder.go index 0310a13f..b047cb8a 100644 --- a/pkg/go/graph/graph_builder.go +++ b/pkg/go/graph/graph_builder.go @@ -11,6 +11,12 @@ import ( "gonum.org/v1/gonum/graph/multi" ) +const ( + union = "union" + intersection = "intersection" + exclusion = "exclusion" +) + type AuthorizationModelGraphBuilder struct { graph.DirectedMultigraphBuilder @@ -92,15 +98,15 @@ func checkRewrite(graphBuilder *AuthorizationModelGraphBuilder, parentNode *Auth return case *openfgav1.Userset_Union: - operator = "union" + operator = union children = rw.Union.GetChild() case *openfgav1.Userset_Intersection: - operator = "intersection" + operator = intersection children = rw.Intersection.GetChild() case *openfgav1.Userset_Difference: - operator = "exclusion" + operator = exclusion children = []*openfgav1.Userset{ rw.Difference.GetBase(), rw.Difference.GetSubtract(), diff --git a/pkg/go/graph/graph_builder_test.go b/pkg/go/graph/graph_builder_test.go index 1c25be22..3a8d5b83 100644 --- a/pkg/go/graph/graph_builder_test.go +++ b/pkg/go/graph/graph_builder_test.go @@ -227,8 +227,8 @@ rankdir=BT 3 -> 2 [style=dashed]; }`, cycleInformation: CycleInformation{ - hasCyclesAtCompileTime: true, - canHaveCyclesAtRuntime: false, + hasCyclesAtCompileTime: [][]string{{"folder#x", "folder#y", "folder#z", "folder#x"}}, + canHaveCyclesAtRuntime: nil, }, }, `computed_relation_with_size_two`: { @@ -254,8 +254,8 @@ rankdir=BT 2 -> 1 [style=dashed]; }`, cycleInformation: CycleInformation{ - hasCyclesAtCompileTime: true, - canHaveCyclesAtRuntime: false, + hasCyclesAtCompileTime: [][]string{{"folder#x", "folder#y", "folder#x"}}, + canHaveCyclesAtRuntime: nil, }, }, `tuple_to_userset_one_related_type`: { @@ -317,8 +317,8 @@ rankdir=BT 4 -> 3 [label=direct]; }`, cycleInformation: CycleInformation{ - hasCyclesAtCompileTime: false, - canHaveCyclesAtRuntime: true, + hasCyclesAtCompileTime: nil, + canHaveCyclesAtRuntime: [][]string{{"folder#viewer"}}, }, }, `tuple_to_userset_two_related_types`: { @@ -631,8 +631,8 @@ rankdir=BT 7 -> 6; }`, cycleInformation: CycleInformation{ - hasCyclesAtCompileTime: false, - canHaveCyclesAtRuntime: true, + hasCyclesAtCompileTime: nil, + canHaveCyclesAtRuntime: [][]string{{"folder#a", "folder#b", "folder#c"}}, }, }, `multigraph`: { @@ -820,8 +820,8 @@ rankdir=BT 9 -> 4 [style=dashed]; }`, cycleInformation: CycleInformation{ - hasCyclesAtCompileTime: true, - canHaveCyclesAtRuntime: true, + hasCyclesAtCompileTime: [][]string{{"folder#x", "folder#y"}}, + canHaveCyclesAtRuntime: [][]string{{"folder#a", "folder#b", "folder#c"}}, }, }, `potential_cycle_or_but_not`: { @@ -861,8 +861,8 @@ rankdir=BT 7 -> 6; }`, cycleInformation: CycleInformation{ - hasCyclesAtCompileTime: false, - canHaveCyclesAtRuntime: true, + hasCyclesAtCompileTime: nil, + canHaveCyclesAtRuntime: [][]string{{"resource#x", "resource#y", "resource#z"}}, }, }, `potential_cycle_four_union`: { @@ -916,8 +916,20 @@ rankdir=BT 9 -> 6; }`, cycleInformation: CycleInformation{ - hasCyclesAtCompileTime: false, - canHaveCyclesAtRuntime: true, + hasCyclesAtCompileTime: nil, + canHaveCyclesAtRuntime: [][]string{ + {"group#member", "group#memberA"}, + {"group#member", "group#memberB"}, + {"group#member", "group#memberC"}, + {"group#memberA", "group#memberB"}, + {"group#memberA", "group#memberC"}, + {"group#memberB", "group#memberC"}, + {"group#member", "group#memberA", "group#memberB"}, + {"group#member", "group#memberA", "group#memberC"}, + {"group#member", "group#memberB", "group#memberC"}, + {"group#memberA", "group#memberB", "group#memberC"}, + {"group#member", "group#memberA", "group#memberB", "group#memberC"}, + }, }, }, `potential_cycle_four_union_with_one_member_no_union`: { @@ -966,8 +978,13 @@ rankdir=BT 8 -> 5; }`, cycleInformation: CycleInformation{ - hasCyclesAtCompileTime: false, - canHaveCyclesAtRuntime: true, + hasCyclesAtCompileTime: nil, + canHaveCyclesAtRuntime: [][]string{ + {"account#admin", "account#member"}, + {"account#admin", "account#super_admin"}, + {"account#member", "account#super_admin"}, + {"account#admin", "account#member", "account#super_admin"}, + }, }, }, `intersection`: { @@ -1014,8 +1031,13 @@ rankdir=BT 8 -> 3 [label=direct]; }`, cycleInformation: CycleInformation{ - hasCyclesAtCompileTime: false, - canHaveCyclesAtRuntime: true, + hasCyclesAtCompileTime: nil, + canHaveCyclesAtRuntime: [][]string{ + {"document#action1", "document#action2"}, + {"document#action1", "document#action3"}, + {"document#action2", "document#action3"}, + {"document#action1", "document#action2", "document#action3"}, + }, }, }, } @@ -1035,7 +1057,17 @@ rankdir=BT diff := cmp.Diff(expectedSorted, actualSorted) require.Empty(t, diff, "expected %s\ngot\n%s", testCase.expectedOutput, actualDOT) - require.Equal(t, testCase.cycleInformation, graph.GetCycles()) + cycleInfo := graph.GetCycles() + if len(testCase.cycleInformation.canHaveCyclesAtRuntime) > 0 { + require.Equal(t, testCase.cycleInformation.SortedCanHaveCyclesAtRuntime(), cycleInfo.SortedCanHaveCyclesAtRuntime()) + } else { + require.Empty(t, cycleInfo.SortedCanHaveCyclesAtRuntime()) + } + if len(testCase.cycleInformation.hasCyclesAtCompileTime) > 0 { + require.Equal(t, testCase.cycleInformation.SortedHasCyclesAtCompileTime(), cycleInfo.SortedHasCyclesAtCompileTime()) + } else { + require.Empty(t, cycleInfo.SortedHasCyclesAtCompileTime()) + } }) } } @@ -1114,3 +1146,90 @@ func getSorted(input string) string { return strings.Join(lines, "\n") } + +func TestSortAndRemoveDuplicateAndAlgebraic(t *testing.T) { + t.Parallel() + tests := []struct { + name string + input [][]string + expected [][]string + }{ + { + name: "empty", + input: [][]string{}, + expected: [][]string{}, + }, + { + name: "single_line_with_no_duplicate", + input: [][]string{ + {"c", "b", "a"}, + }, + expected: [][]string{ + {"a", "b", "c"}, + }, + }, + { + name: "single_line_with_duplicate_union_intersection_exclusion", + input: [][]string{ + {"c", "exclusion", "b", "union", "c", "a", "intersection", "a", "a"}, + }, + expected: [][]string{ + {"a", "b", "c"}, + }, + }, + { + name: "multiple_line_with_no_duplicate", + input: [][]string{ + {"y", "x", "z"}, + {"c", "b", "a"}, + {"a", "c", "d", "b"}, + }, + expected: [][]string{ + {"a", "b", "c"}, + {"x", "y", "z"}, + {"a", "b", "c", "d"}, + }, + }, + { + name: "multiple_line_with_duplicate_items", + input: [][]string{ + {"y", "x", "x", "z"}, + {"c", "b", "a", "a"}, + {"a", "c", "d", "b"}, + }, + expected: [][]string{ + {"a", "b", "c"}, + {"x", "y", "z"}, + {"a", "b", "c", "d"}, + }, + }, + { + name: "multiple_line_difference_last_item", + input: [][]string{ + {"c", "b", "a", "e"}, + {"a", "c", "d", "b"}, + }, + expected: [][]string{ + {"a", "b", "c", "d"}, + {"a", "b", "c", "e"}, + }, + }, + { + name: "duplicate_lines", + input: [][]string{ + {"c", "b", "a", "d"}, + {"a", "c", "d", "b"}, + }, + expected: [][]string{ + {"a", "b", "c", "d"}, + }, + }, + } + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + t.Parallel() + output := sortAndRemoveDuplicateAndAlgebraic(tt.input) + require.Equal(t, tt.expected, output) + }) + } +} diff --git a/pkg/go/utils/slices_of_slices.go b/pkg/go/utils/slices_of_slices.go new file mode 100644 index 00000000..41da847e --- /dev/null +++ b/pkg/go/utils/slices_of_slices.go @@ -0,0 +1,29 @@ +package utils + +type SlicesOfSlices [][]string + +func (s SlicesOfSlices) Len() int { return len(s) } + +func (s SlicesOfSlices) Less(i, j int) bool { + if len(s[i]) < len(s[j]) { + return true + } + if len(s[i]) > len(s[j]) { + return false + } + // the length is equal, sort according to item(from first to last) + for k := range s[i] { + if s[i][k] < s[j][k] { + return true + } + if s[i][k] > s[j][k] { + return false + } + } + + return true +} + +func (s SlicesOfSlices) Swap(i, j int) { + s[i], s[j] = s[j], s[i] +} diff --git a/pkg/go/utils/slices_of_slices_test.go b/pkg/go/utils/slices_of_slices_test.go new file mode 100644 index 00000000..44d1a0a8 --- /dev/null +++ b/pkg/go/utils/slices_of_slices_test.go @@ -0,0 +1,80 @@ +package utils + +import ( + "testing" + + "github.com/stretchr/testify/require" +) + +func TestSlicesOfSlices_Less(t *testing.T) { + t.Parallel() + tests := []struct { + name string + input SlicesOfSlices + want bool + }{ + { + name: "first_slice_longer", + input: SlicesOfSlices{ + {"a", "b"}, + {}, + }, + want: false, + }, + { + name: "second_slice_longer", + input: SlicesOfSlices{ + {}, + {"a", "b"}, + }, + want: true, + }, + { + name: "equal_length_first_slice_smaller_element", + input: SlicesOfSlices{ + {"a", "b"}, + {"x", "b"}, + }, + want: true, + }, + { + name: "equal_length_first_slice_larger_element", + input: SlicesOfSlices{ + {"x", "b"}, + {"a", "b"}, + }, + want: false, + }, + { + name: "equal_length_first_slice_smaller_last_element", + input: SlicesOfSlices{ + {"a", "b"}, + {"a", "c"}, + }, + want: true, + }, + { + name: "equal_length_second_slice_smaller_last_element", + input: SlicesOfSlices{ + {"a", "c"}, + {"a", "b"}, + }, + want: false, + }, + { + name: "equal_everything", + input: SlicesOfSlices{ + {"a", "b"}, + {"a", "b"}, + }, + want: true, + }, + } + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + t.Parallel() + output := tt.input.Less(0, 1) + require.Equal(t, tt.want, output) + }) + } +}