Skip to content

Commit

Permalink
compile: Add support for unknowns other than input
Browse files Browse the repository at this point in the history
Previously the build command and the compile package would only mark
input as unknown. If documents under data needed to be treated as
unknown, there was no solution. This commit updates the compile
package to infer unknowns based on the bundle roots. If the policy
refers to a data document _outside_ of one of the bundle roots, that
data document will be marked as unknown during optimization/partial
eval.

Fixes #2581

Signed-off-by: Torin Sandall <[email protected]>
  • Loading branch information
tsandall committed Aug 17, 2020
1 parent 0d54ba2 commit fda63bd
Show file tree
Hide file tree
Showing 3 changed files with 195 additions and 4 deletions.
94 changes: 94 additions & 0 deletions compile/compile.go
Original file line number Diff line number Diff line change
Expand Up @@ -21,6 +21,7 @@ import (
initload "github.com/open-policy-agent/opa/internal/runtime/init"
"github.com/open-policy-agent/opa/loader"
"github.com/open-policy-agent/opa/rego"
"github.com/open-policy-agent/opa/storage"
"github.com/open-policy-agent/opa/storage/inmem"
)

Expand Down Expand Up @@ -428,6 +429,7 @@ func (o *optimizer) Do(ctx context.Context) error {
store := inmem.NewFromObject(data)
resultsym := ast.VarTerm(o.resultsymprefix + "__result__")
usedFilenames := map[string]int{}
var unknowns []*ast.Term

// NOTE(tsandall): the entrypoints are optimized in order so that the optimization
// of entrypoint[1] sees the optimization of entrypoint[0] and so on. This is needed
Expand All @@ -442,12 +444,17 @@ func (o *optimizer) Do(ctx context.Context) error {
return err
}

if unknowns == nil {
unknowns = o.findUnknowns()
}

r := rego.New(
rego.ParsedQuery(ast.NewBody(ast.Equality.Expr(resultsym, e))),
rego.PartialNamespace(o.nsprefix),
rego.DisableInlining(o.findRequiredDocuments(e)),
rego.ShallowInlining(o.shallow),
rego.SkipPartialNamespace(true),
rego.ParsedUnknowns(unknowns),
rego.Compiler(o.compiler),
rego.Store(store),
)
Expand Down Expand Up @@ -539,6 +546,35 @@ func (o *optimizer) findRequiredDocuments(ref *ast.Term) []string {
return result
}

func (o *optimizer) findUnknowns() []*ast.Term {

// Initialize set of refs representing the bundle roots.
refs := newRefSet(stringsToRefs(*o.bundle.Manifest.Roots)...)

// Initialize set of refs for the result (i.e., refs outside the bundle roots.)
unknowns := newRefSet(ast.InputRootRef)

// Find data references that are not prefixed by one of the roots.
for _, module := range o.compiler.Modules {
ast.WalkRefs(module, func(x ast.Ref) bool {
prefix := x.ConstantPrefix()
if !prefix.HasPrefix(ast.DefaultRootRef) {
return true
}
if !refs.ContainsPrefix(prefix) {
o.debug.Add(Debug{
Location: x[0].Location,
Message: fmt.Sprintf("marking %v as unknown", prefix),
})
unknowns.AddPrefix(prefix)
}
return false
})
}

return unknowns.Sorted()
}

func (o *optimizer) getSupportForEntrypoint(queries []ast.Body, e *ast.Term, resultsym *ast.Term) *ast.Module {

path := e.Value.(ast.Ref)
Expand Down Expand Up @@ -697,3 +733,61 @@ func (ss orderedStringSet) Append(s ...string) orderedStringSet {
}
return ss
}

func stringsToRefs(x []string) []ast.Ref {
result := make([]ast.Ref, len(x))
for i := range result {
result[i] = storage.MustParsePath("/" + x[i]).Ref(ast.DefaultRootDocument)
}
return result
}

type refSet struct {
s []ast.Ref
}

func newRefSet(x ...ast.Ref) *refSet {
result := &refSet{}
for i := range x {
result.AddPrefix(x[i])
}
return result
}

// ContainsPrefix returns true if r is prefixed by any of the existing refs in the set.
func (rs *refSet) ContainsPrefix(r ast.Ref) bool {
for i := range rs.s {
if r.HasPrefix(rs.s[i]) {
return true
}
}
return false
}

// AddPrefix inserts r into the set if r is not prefixed by any existing
// refs in the set. If any existing refs are prefixed by r, those existing
// refs are removed.
func (rs *refSet) AddPrefix(r ast.Ref) {
if rs.ContainsPrefix(r) {
return
}
cpy := []ast.Ref{r}
for i := range rs.s {
if !rs.s[i].HasPrefix(r) {
cpy = append(cpy, rs.s[i])
}
}
rs.s = cpy
}

// Sorted returns a sorted slice of terms for refs in the set.
func (rs *refSet) Sorted() []*ast.Term {
terms := make([]*ast.Term, len(rs.s))
for i := range rs.s {
terms[i] = ast.NewTerm(rs.s[i])
}
sort.Slice(terms, func(i, j int) bool {
return terms[i].Value.Compare(terms[j].Value) < 0
})
return terms
}
101 changes: 97 additions & 4 deletions compile/compile_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -509,7 +509,7 @@ func TestOptimizerNoops(t *testing.T) {

for _, tc := range tests {
t.Run(tc.note, func(t *testing.T) {
o := getOptimizer(tc.modules, "", tc.entrypoints)
o := getOptimizer(tc.modules, "", tc.entrypoints, nil)
cpy := o.bundle.Copy()
err := o.Do(context.Background())
if err != nil {
Expand Down Expand Up @@ -560,7 +560,7 @@ func TestOptimizerErrors(t *testing.T) {

for _, tc := range tests {
t.Run(tc.note, func(t *testing.T) {
o := getOptimizer(tc.modules, "", tc.entrypoints)
o := getOptimizer(tc.modules, "", tc.entrypoints, nil)
cpy := o.bundle.Copy()
got := o.Do(context.Background())
if got == nil || got.Error() != tc.wantErr.Error() {
Expand Down Expand Up @@ -808,12 +808,44 @@ func TestOptimizerOutput(t *testing.T) {
`,
},
},
{
note: "infer unknowns from roots",
entrypoints: []string{"data.test.p"},
modules: map[string]string{
"test.rego": `
package test
p {
q[x]
data.external.users[x] == input.user
}
q["foo"]
q["bar"]
`,
},
roots: []string{"test"},
wantModules: map[string]string{
"optimized/test.rego": `
package test
p = __result__ { data.external.users.foo = input.user; __result__ = true }
p = __result__ { data.external.users.bar = input.user; __result__ = true }
`,
"test.rego": `
package test
q["foo"]
q["bar"]
`,
},
},
}

for _, tc := range tests {
t.Run(tc.note, func(t *testing.T) {

o := getOptimizer(tc.modules, tc.data, tc.entrypoints)
o := getOptimizer(tc.modules, tc.data, tc.entrypoints, tc.roots)
original := o.bundle.Copy()
err := o.Do(context.Background())
if err != nil {
Expand All @@ -825,6 +857,11 @@ func TestOptimizerOutput(t *testing.T) {
Data: original.Data, // data is not pruned at all today
}

if len(tc.roots) > 0 {
exp.Manifest.Roots = &tc.roots
exp.Manifest.AddRoot("partial") // optimizer will add this automatically
}

exp.Manifest.Revision = "" // optimizations must reset the revision.

if !exp.Equal(*o.bundle) {
Expand All @@ -838,7 +875,59 @@ func TestOptimizerOutput(t *testing.T) {
}
}

func getOptimizer(modules map[string]string, data string, entries []string) *optimizer {
func TestRefSet(t *testing.T) {
rs := newRefSet(ast.MustParseRef("input"), ast.MustParseRef("data.foo.bar"))

expFound := []string{
"input",
"input.foo",
"data.foo.bar",
"data.foo.bar.baz",
"data.foo.bar[1]",
}

for _, exp := range expFound {
if !rs.ContainsPrefix(ast.MustParseRef(exp)) {
t.Fatal("expected to find:", exp)
}
}

expNotFound := []string{
"x.bar",
"data",
"data.bar",
"data.foo",
}

for _, exp := range expNotFound {
if rs.ContainsPrefix(ast.MustParseRef(exp)) {
t.Fatal("expected not to find:", exp)
}
}

rs.AddPrefix(ast.MustParseRef("data.foo"))

if !rs.ContainsPrefix(ast.MustParseRef("data.foo")) {
t.Fatal("expected to find data.foo after adding to set")
}

sorted := rs.Sorted()

if len(sorted) != 2 || !sorted[0].Equal(ast.MustParseTerm("data.foo")) || !sorted[1].Equal(ast.MustParseTerm("input")) {
t.Fatal("expected 2 prefixes (data.foo and input) but got:", sorted)
}

// The prefixes should not be affected (because data.foo already exists).
rs.AddPrefix(ast.MustParseRef("data.foo.qux"))
sorted = rs.Sorted()

if len(sorted) != 2 || !sorted[0].Equal(ast.MustParseTerm("data.foo")) || !sorted[1].Equal(ast.MustParseTerm("input")) {
t.Fatal("expected 2 prefixes (data.foo and input) but got:", sorted)
}

}

func getOptimizer(modules map[string]string, data string, entries []string, roots []string) *optimizer {

b := &bundle.Bundle{
Modules: getModuleFiles(modules, true),
Expand All @@ -848,6 +937,10 @@ func getOptimizer(modules map[string]string, data string, entries []string) *opt
b.Data = util.MustUnmarshalJSON([]byte(data)).(map[string]interface{})
}

if len(roots) > 0 {
b.Manifest.Roots = &roots
}

b.Manifest.Init()
b.Manifest.Revision = "DEADBEEF" // ensures that the manifest revision is getting reset
entrypoints := make([]*ast.Term, len(entries))
Expand Down
4 changes: 4 additions & 0 deletions docs/content/policy-performance.md
Original file line number Diff line number Diff line change
Expand Up @@ -714,6 +714,10 @@ Rules that depend on unknowns (directly or indirectly) are also partially evalua
virtual documents they produce ARE NOT inlined into call sites. The output policy should be structurally
similar to the input policy.

The `opa build` automatically marks the `input` document as unknown. In addition to the `input` document,
if `opa build` is invoked with the `-b`/`--bundle` flag, any `data` references NOT prefixed by the
`.manifest` roots are also marked as unknown.

### -O=2 (aggressive)

Same as `-O=1` except virtual documents produced by rules that depend on unknowns may be inlined
Expand Down

0 comments on commit fda63bd

Please sign in to comment.