diff --git a/internal/storage/mock/mock.go b/internal/storage/mock/mock.go new file mode 100644 index 0000000000..a97cc67066 --- /dev/null +++ b/internal/storage/mock/mock.go @@ -0,0 +1,240 @@ +// Copyright 2019 The OPA Authors. All rights reserved. +// Use of this source code is governed by an Apache2 +// license that can be found in the LICENSE file. + +// Package mock defines a fake storage implementation for use in testing. +package mock + +import ( + "context" + "fmt" + "github.com/open-policy-agent/opa/ast" + "github.com/open-policy-agent/opa/storage" + "github.com/open-policy-agent/opa/storage/inmem" + "testing" +) + +// Transaction is a mock storage.Transaction implementation for use in testing. +// It uses an internal storage.Transaction pointer with some added functionality. +type Transaction struct { + txn storage.Transaction + Committed int + Aborted int +} + +// ID returns the underlying transaction ID +func (t *Transaction) ID() uint64 { + return t.txn.ID() +} + +// Validate returns an error if the transaction is in an invalid state +func (t *Transaction) Validate() error { + if t.Committed > 1 { + return fmt.Errorf("transaction %d has too many commits (%d)", t.ID(), t.Committed) + } + if t.Aborted > 1 { + return fmt.Errorf("transaction %d has too many aborts (%d)", t.ID(), t.Committed) + } + return nil +} + +func (t *Transaction) safeToUse() bool { + return t.Committed == 0 && t.Aborted == 0 +} + +// Store is a mock storage.Store implementation for use in testing. +type Store struct { + inmem storage.Store + Transactions []*Transaction + Reads []*ReadCall + Writes []*WriteCall +} + +// ReadCall captures the parameters for a Read call +type ReadCall struct { + Transaction *Transaction + Path storage.Path + Error error + Safe bool +} + +// WriteCall captures the parameters for a write call +type WriteCall struct { + Transaction *Transaction + Op storage.PatchOp + Path storage.Path + Error error + Safe bool +} + +// New creates a new mock Store +func New() *Store { + s := &Store{} + s.Reset() + return s +} + +// Reset the store +func (s *Store) Reset() { + s.Transactions = []*Transaction{} + s.Reads = []*ReadCall{} + s.Writes = []*WriteCall{} + s.inmem = inmem.New() +} + +// GetTransaction will a transaction with a specific ID +// that was associated with this Store. +func (s *Store) GetTransaction(id uint64) *Transaction { + for _, txn := range s.Transactions { + if txn.ID() == id { + return txn + } + } + return nil +} + +// Errors retuns a list of errors for each invalid state found. +// If any Transactions are invalid or reads/writes were +// unsafe an error will be returned for each problem. +func (s *Store) Errors() []error { + var errs []error + for _, txn := range s.Transactions { + err := txn.Validate() + if err != nil { + errs = append(errs, err) + } + } + + for _, read := range s.Reads { + if !read.Safe { + errs = append(errs, fmt.Errorf("unsafe Read call %+v", *read)) + } + } + + for _, write := range s.Writes { + if !write.Safe { + errs = append(errs, fmt.Errorf("unsafe Write call %+v", *write)) + } + } + + return errs +} + +// AssertValid will raise an error with the provided testing.T if +// there are any errors on the store. +func (s *Store) AssertValid(t *testing.T) { + t.Helper() + for _, err := range s.Errors() { + t.Errorf("Error detected on store: %s", err) + } +} + +// storage.Store interface implementation + +// Register just shims the call to the underlying inmem store +func (s *Store) Register(ctx context.Context, txn storage.Transaction, config storage.TriggerConfig) (storage.TriggerHandle, error) { + return s.inmem.Register(ctx, txn, config) +} + +// ListPolicies just shims the call to the underlying inmem store +func (s *Store) ListPolicies(ctx context.Context, txn storage.Transaction) ([]string, error) { + return s.ListPolicies(ctx, txn) +} + +// GetPolicy just shims the call to the underlying inmem store +func (s *Store) GetPolicy(ctx context.Context, txn storage.Transaction, name string) ([]byte, error) { + return s.inmem.GetPolicy(ctx, txn, name) +} + +// UpsertPolicy just shims the call to the underlying inmem store +func (s *Store) UpsertPolicy(ctx context.Context, txn storage.Transaction, name string, policy []byte) error { + return s.inmem.UpsertPolicy(ctx, txn, name, policy) +} + +// DeletePolicy just shims the call to the underlying inmem store +func (s *Store) DeletePolicy(ctx context.Context, txn storage.Transaction, name string) error { + return s.inmem.DeletePolicy(ctx, txn, name) +} + +// Build just shims the call to the underlying inmem store +func (s *Store) Build(ctx context.Context, txn storage.Transaction, ref ast.Ref) (storage.Index, error) { + return s.inmem.Build(ctx, txn, ref) +} + +// NewTransaction will create a new transaction on the underlying inmem store +// but wraps it with a mock Transaction. These are then tracked on the store. +func (s *Store) NewTransaction(ctx context.Context, params ...storage.TransactionParams) (storage.Transaction, error) { + realTxn, err := s.inmem.NewTransaction(ctx, params...) + if err != nil { + return nil, err + } + txn := &Transaction{ + txn: realTxn, + Committed: 0, + Aborted: 0, + } + s.Transactions = append(s.Transactions, txn) + return txn, nil +} + +// Read will make a read from the underlying inmem store and +// add a new entry to the mock store Reads list. If there +// is an error are the read is unsafe it will be noted in +// the ReadCall. +func (s *Store) Read(ctx context.Context, txn storage.Transaction, path storage.Path) (interface{}, error) { + mockTxn := txn.(*Transaction) + + data, err := s.inmem.Read(ctx, mockTxn.txn, path) + + s.Reads = append(s.Reads, &ReadCall{ + Transaction: mockTxn, + Path: path, + Error: err, + Safe: mockTxn.safeToUse(), + }) + + return data, err +} + +// Write will make a read from the underlying inmem store and +// add a new entry to the mock store Writes list. If there +// is an error are the write is unsafe it will be noted in +// the WriteCall. +func (s *Store) Write(ctx context.Context, txn storage.Transaction, op storage.PatchOp, path storage.Path, value interface{}) error { + mockTxn := txn.(*Transaction) + + err := s.inmem.Write(ctx, mockTxn.txn, op, path, value) + + s.Writes = append(s.Writes, &WriteCall{ + Transaction: mockTxn, + Op: op, + Path: path, + Error: err, + Safe: mockTxn.safeToUse(), + }) + + return nil +} + +// Commit will commit the underlying transaction while +// also updating the mock Transaction +func (s *Store) Commit(ctx context.Context, txn storage.Transaction) error { + mockTxn := txn.(*Transaction) + + err := s.inmem.Commit(ctx, mockTxn.txn) + if err != nil { + return err + } + + mockTxn.Committed++ + return nil +} + +// Abort will abort the underlying transaction while +// also updating the mock Transaction +func (s *Store) Abort(ctx context.Context, txn storage.Transaction) { + mockTxn := txn.(*Transaction) + s.inmem.Abort(ctx, mockTxn.txn) + mockTxn.Aborted++ + return +} diff --git a/rego/rego.go b/rego/rego.go index 04233f4f03..f3beda1003 100644 --- a/rego/rego.go +++ b/rego/rego.go @@ -157,13 +157,17 @@ func EvalParsedUnknowns(unknowns []*ast.Term) EvalOption { } } -func (pq preparedQuery) newEvalContext(ctx context.Context, options []EvalOption) (*EvalContext, error) { +// newEvalContext creates a new EvalContext overlaying any EvalOptions over top +// the Rego object on the preparedQuery. The returned function should be called +// once the evaluation is complete to close any transactions that might have +// been opened. +func (pq preparedQuery) newEvalContext(ctx context.Context, options []EvalOption) (*EvalContext, func(context.Context), error) { ectx := &EvalContext{ hasInput: false, rawInput: nil, parsedInput: nil, metrics: pq.r.metrics, - txn: pq.r.txn, + txn: nil, instrument: pq.r.instrument, instrumentation: pq.r.instrumentation, partialNamespace: pq.r.partialNamespace, @@ -181,13 +185,18 @@ func (pq preparedQuery) newEvalContext(ctx context.Context, options []EvalOption ectx.instrumentation = topdown.NewInstrumentation(ectx.metrics) } + // Default to an empty "finish" function + finishFunc := func(context.Context) {} + var err error if ectx.txn == nil { ectx.txn, err = pq.r.store.NewTransaction(ctx) if err != nil { - return nil, err + return nil, finishFunc, err + } + finishFunc = func(ctx context.Context) { + pq.r.store.Abort(ctx, ectx.txn) } - defer pq.r.store.Abort(ctx, ectx.txn) } // If we didn't get an input specified in the Eval options @@ -205,11 +214,11 @@ func (pq preparedQuery) newEvalContext(ctx context.Context, options []EvalOption } ectx.parsedInput, err = pq.r.parseRawInput(ectx.rawInput, ectx.metrics) if err != nil { - return nil, err + return nil, finishFunc, err } } - return ectx, nil + return ectx, finishFunc, nil } // PreparedEvalQuery holds the prepared Rego state that has been pre-processed @@ -221,11 +230,14 @@ type PreparedEvalQuery struct { // Eval evaluates this PartialResult's Rego object with additional eval options // and returns a ResultSet. // If options are provided they will override the original Rego options respective value. +// The original Rego object transaction will *not* be re-used. A new transaction will be opened +// if one is not provided with an EvalOption. func (pq PreparedEvalQuery) Eval(ctx context.Context, options ...EvalOption) (ResultSet, error) { - ectx, err := pq.newEvalContext(ctx, options) + ectx, finish, err := pq.newEvalContext(ctx, options) if err != nil { return nil, err } + defer finish(ctx) ectx.compiledQuery = pq.r.compiledQueries[evalQueryType] @@ -239,11 +251,14 @@ type PreparedPartialQuery struct { } // Partial runs partial evaluation on the prepared query and returns the result. +// The original Rego object transaction will *not* be re-used. A new transaction will be opened +// if one is not provided with an EvalOption. func (pq PreparedPartialQuery) Partial(ctx context.Context, options ...EvalOption) (*PartialQueries, error) { - ectx, err := pq.newEvalContext(ctx, options) + ectx, finish, err := pq.newEvalContext(ctx, options) if err != nil { return nil, err } + defer finish(ctx) ectx.compiledQuery = pq.r.compiledQueries[partialQueryType] @@ -606,7 +621,7 @@ func (r *Rego) Eval(ctx context.Context) (ResultSet, error) { return nil, err } - return pq.Eval(ctx) + return pq.Eval(ctx, EvalTransaction(r.txn)) } // PartialEval has been deprecated and renamed to PartialResult. @@ -655,7 +670,7 @@ func (r *Rego) Partial(ctx context.Context) (*PartialQueries, error) { return nil, err } - return pq.Partial(ctx) + return pq.Partial(ctx, EvalTransaction(r.txn)) } // CompileOption defines a function to set options on Compile calls. diff --git a/rego/rego_test.go b/rego/rego_test.go index f45c230a83..feff7c25ac 100644 --- a/rego/rego_test.go +++ b/rego/rego_test.go @@ -7,6 +7,7 @@ package rego import ( "context" "encoding/json" + "github.com/open-policy-agent/opa/internal/storage/mock" "reflect" "testing" "time" @@ -439,13 +440,13 @@ func TestPrepareAndEvalNewMetrics(t *testing.T) { } } -func TestPrepareAndEvalNewTransaction(t *testing.T) { +func TestPrepareAndEvalTransaction(t *testing.T) { module := ` package test x = data.foo.y ` ctx := context.Background() - store := inmem.New() + store := mock.New() txn := storage.NewTransactionOrDie(ctx, store, storage.WriteParams) path, ok := storage.ParsePath("/foo") @@ -475,21 +476,69 @@ func TestPrepareAndEvalNewTransaction(t *testing.T) { t.Fatalf("Unexpected error: %s", err.Error()) } - assertPreparedEvalQueryEval(t, pq, nil, "[[1]]") - store.Commit(ctx, txn) + // Base case, expect it to use the transaction provided + assertPreparedEvalQueryEval(t, pq, []EvalOption{EvalTransaction(txn)}, "[[1]]") + + mockTxn := store.GetTransaction(txn.ID()) + for _, read := range store.Reads { + if read.Transaction != mockTxn { + t.Errorf("Found read operation with an invalid transaction, expected: %d, found: %d", mockTxn.ID(), read.Transaction.ID()) + } + } - // Update the store directly and get a new transaction - newTxn := storage.NewTransactionOrDie(ctx, store, storage.WriteParams) - err = store.Write(ctx, newTxn, storage.ReplaceOp, path, map[string]interface{}{"y": 2}) + store.AssertValid(t) + store.Reset() + + // Case with an update to the store and a new transaction + txn = storage.NewTransactionOrDie(ctx, store, storage.WriteParams) + err = store.Write(ctx, txn, storage.AddOp, path, map[string]interface{}{"y": 2}) if err != nil { t.Fatalf("Unexpected error writing to store: %s", err.Error()) } - defer store.Abort(ctx, newTxn) - // Expect that the old transaction and new transaction give - // different results. - assertPreparedEvalQueryEval(t, pq, []EvalOption{EvalTransaction(txn)}, "[[1]]") - assertPreparedEvalQueryEval(t, pq, []EvalOption{EvalTransaction(newTxn)}, "[[2]]") + // Expect the new result from the updated value on this transaction + assertPreparedEvalQueryEval(t, pq, []EvalOption{EvalTransaction(txn)}, "[[2]]") + + err = store.Commit(ctx, txn) + if err != nil { + t.Fatalf("Unexpected error committing to store: %s", err) + } + + newMockTxn := store.GetTransaction(txn.ID()) + for _, read := range store.Reads { + if read.Transaction != newMockTxn { + t.Errorf("Found read operation with an invalid transaction, expected: %d, found: %d", mockTxn.ID(), read.Transaction.ID()) + } + } + + store.AssertValid(t) + store.Reset() + + // Case with no transaction provided, should create a new one and see the latest value + txn = storage.NewTransactionOrDie(ctx, store, storage.WriteParams) + err = store.Write(ctx, txn, storage.AddOp, path, map[string]interface{}{"y": 3}) + if err != nil { + t.Fatalf("Unexpected error writing to store: %s", err.Error()) + } + err = store.Commit(ctx, txn) + if err != nil { + t.Fatalf("Unexpected error committing to store: %s", err) + } + + assertPreparedEvalQueryEval(t, pq, nil, "[[3]]") + + if len(store.Transactions) != 2 { + t.Fatalf("Expected only two transactions on store, found %d", len(store.Transactions)) + } + + autoTxn := store.Transactions[1] + for _, read := range store.Reads { + if read.Transaction != autoTxn { + t.Errorf("Found read operation with an invalid transaction, expected: %d, found: %d", autoTxn, read.Transaction.ID()) + } + } + store.AssertValid(t) + } func TestPrepareAndEvalIdempotent(t *testing.T) {