diff --git a/plugins/logs/plugin.go b/plugins/logs/plugin.go index 8b70730afb..9a86aca73a 100644 --- a/plugins/logs/plugin.go +++ b/plugins/logs/plugin.go @@ -279,7 +279,7 @@ func (p *Plugin) Log(ctx context.Context, decision *server.Info) error { return proxy.Log(ctx, event) } - err := p.maskEvent(ctx, &event) + err := p.maskEvent(ctx, decision.Txn, &event) if err != nil { // TODO(tsandall): see note below about error handling. p.logError("Log event masking failed: %v.", err) @@ -443,7 +443,7 @@ func (p *Plugin) bufferChunk(buffer *logBuffer, bs []byte) { } } -func (p *Plugin) maskEvent(ctx context.Context, event *EventV1) error { +func (p *Plugin) maskEvent(ctx context.Context, txn storage.Transaction, event *EventV1) error { err := func() error { @@ -458,6 +458,7 @@ func (p *Plugin) maskEvent(ctx context.Context, event *EventV1) error { rego.ParsedQuery(query), rego.Compiler(p.manager.GetCompiler()), rego.Store(p.manager.Store), + rego.Transaction(txn), rego.Runtime(p.manager.Info), ) diff --git a/plugins/logs/plugin_test.go b/plugins/logs/plugin_test.go index 1e8f219332..6ac405a944 100644 --- a/plugins/logs/plugin_test.go +++ b/plugins/logs/plugin_test.go @@ -483,7 +483,7 @@ func TestPluginMasking(t *testing.T) { event := &EventV1{ Input: &input, } - if err := plugin.maskEvent(ctx, event); err != nil { + if err := plugin.maskEvent(ctx, nil, event); err != nil { t.Fatal(err) } @@ -510,7 +510,7 @@ func TestPluginMasking(t *testing.T) { Input: &input, } - if err := plugin.maskEvent(ctx, event); err != nil { + if err := plugin.maskEvent(ctx, nil, event); err != nil { t.Fatal(err) } @@ -544,7 +544,7 @@ func TestPluginMasking(t *testing.T) { Input: &input, } - if err := plugin.maskEvent(ctx, event); err != nil { + if err := plugin.maskEvent(ctx, nil, event); err != nil { t.Fatal(err) } @@ -570,7 +570,7 @@ func TestPluginMasking(t *testing.T) { Input: &input, } - if err := plugin.maskEvent(ctx, event); err != nil { + if err := plugin.maskEvent(ctx, nil, event); err != nil { t.Fatal(err) } @@ -739,7 +739,7 @@ func BenchmarkMaskingNop(b *testing.B) { b.StartTimer() - if err := plugin.maskEvent(ctx, &event); err != nil { + if err := plugin.maskEvent(ctx, nil, &event); err != nil { b.Fatal(err) } } @@ -792,7 +792,7 @@ func BenchmarkMaskingErase(b *testing.B) { b.StartTimer() - if err := plugin.maskEvent(ctx, &event); err != nil { + if err := plugin.maskEvent(ctx, nil, &event); err != nil { b.Fatal(err) } diff --git a/server/buffer.go b/server/buffer.go index 5bfc115fb7..dbaf2a5231 100644 --- a/server/buffer.go +++ b/server/buffer.go @@ -9,6 +9,7 @@ import ( "time" "github.com/open-policy-agent/opa/metrics" + "github.com/open-policy-agent/opa/storage" "github.com/open-policy-agent/opa/topdown" ) @@ -72,6 +73,7 @@ func (b *buffer) Iter(fn func(*Info)) { // Info contains information describing a policy decision. type Info struct { + Txn storage.Transaction Revision string DecisionID string RemoteAddr string diff --git a/server/server.go b/server/server.go index da0e9e0df2..ef729df702 100644 --- a/server/server.go +++ b/server/server.go @@ -521,7 +521,7 @@ func (s *Server) initRouter() { s.Handler = router } -func (s *Server) execQuery(ctx context.Context, r *http.Request, decisionID string, parsedQuery ast.Body, input ast.Value, m metrics.Metrics, explainMode types.ExplainModeV1, includeMetrics, includeInstrumentation, pretty bool) (results types.QueryResponseV1, err error) { +func (s *Server) execQuery(ctx context.Context, r *http.Request, txn storage.Transaction, decisionID string, parsedQuery ast.Body, input ast.Value, m metrics.Metrics, explainMode types.ExplainModeV1, includeMetrics, includeInstrumentation, pretty bool) (results types.QueryResponseV1, err error) { diagLogger := s.evalDiagnosticPolicy(r) @@ -549,6 +549,7 @@ func (s *Server) execQuery(ctx context.Context, r *http.Request, decisionID stri rego := rego.New( rego.Store(s.store), + rego.Transaction(txn), rego.Compiler(compiler), rego.ParsedQuery(parsedQuery), rego.ParsedInput(input), @@ -560,7 +561,7 @@ func (s *Server) execQuery(ctx context.Context, r *http.Request, decisionID stri output, err := rego.Eval(ctx) if err != nil { - _ = diagLogger.Log(ctx, decisionID, r.RemoteAddr, "", parsedQuery.String(), rawInput, nil, err, m, buf) + _ = diagLogger.Log(ctx, txn, decisionID, r.RemoteAddr, "", parsedQuery.String(), rawInput, nil, err, m, buf) return results, err } @@ -577,7 +578,7 @@ func (s *Server) execQuery(ctx context.Context, r *http.Request, decisionID stri } var x interface{} = results.Result - err = diagLogger.Log(ctx, decisionID, r.RemoteAddr, "", parsedQuery.String(), rawInput, &x, nil, m, buf) + err = diagLogger.Log(ctx, txn, decisionID, r.RemoteAddr, "", parsedQuery.String(), rawInput, &x, nil, m, buf) return results, err } @@ -618,7 +619,15 @@ func (s *Server) indexGet(w http.ResponseWriter, r *http.Request) { _, parsedQuery, _ := validateQuery(qStr) - results, err := s.execQuery(ctx, r, decisionID, parsedQuery, input, nil, explainMode, false, false, true) + txn, err := s.store.NewTransaction(ctx) + if err != nil { + renderQueryResult(w, nil, err, t0) + return + } + + defer s.store.Abort(ctx, txn) + + results, err := s.execQuery(ctx, r, txn, decisionID, parsedQuery, input, nil, explainMode, false, false, true) if err != nil { renderQueryResult(w, nil, err, t0) return @@ -723,28 +732,24 @@ func (s *Server) v0QueryPath(w http.ResponseWriter, r *http.Request, path ast.Re // Handle results. if err != nil { - _ = diagLogger.Log(ctx, decisionID, r.RemoteAddr, path.String(), "", goInput, nil, err, m, buf) + _ = diagLogger.Log(ctx, txn, decisionID, r.RemoteAddr, path.String(), "", goInput, nil, err, m, buf) writer.ErrorAuto(w, err) return } if len(rs) == 0 { - // construct error to return to client. err := types.NewErrorV1(types.CodeUndefinedDocument, fmt.Sprintf("%v: %v", types.MsgUndefinedError, path)) - // emit decision log - if logErr := diagLogger.Log(ctx, decisionID, r.RemoteAddr, path.String(), "", goInput, nil, err, m, buf); logErr != nil { - // handle case where decision logging fails + if logErr := diagLogger.Log(ctx, txn, decisionID, r.RemoteAddr, path.String(), "", goInput, nil, err, m, buf); logErr != nil { writer.ErrorAuto(w, logErr) return } - // send normal error back to the client writer.Error(w, 404, err) return } - err = diagLogger.Log(ctx, decisionID, r.RemoteAddr, path.String(), "", goInput, &rs[0].Expressions[0].Value, nil, m, buf) + err = diagLogger.Log(ctx, txn, decisionID, r.RemoteAddr, path.String(), "", goInput, &rs[0].Expressions[0].Value, nil, m, buf) if err != nil { writer.ErrorAuto(w, err) return @@ -1015,7 +1020,7 @@ func (s *Server) v1DataGet(w http.ResponseWriter, r *http.Request) { // Handle results. if err != nil { - _ = diagLogger.Log(ctx, decisionID, r.RemoteAddr, path.String(), "", goInput, nil, err, m, buf) + _ = diagLogger.Log(ctx, txn, decisionID, r.RemoteAddr, path.String(), "", goInput, nil, err, m, buf) writer.ErrorAuto(w, err) return } @@ -1041,7 +1046,7 @@ func (s *Server) v1DataGet(w http.ResponseWriter, r *http.Request) { writer.ErrorAuto(w, err) } } - err = diagLogger.Log(ctx, decisionID, r.RemoteAddr, path.String(), "", goInput, nil, nil, m, buf) + err = diagLogger.Log(ctx, txn, decisionID, r.RemoteAddr, path.String(), "", goInput, nil, nil, m, buf) if err != nil { writer.ErrorAuto(w, err) return @@ -1056,7 +1061,7 @@ func (s *Server) v1DataGet(w http.ResponseWriter, r *http.Request) { result.Explanation = s.getExplainResponse(explainMode, *buf, pretty) } - err = diagLogger.Log(ctx, decisionID, r.RemoteAddr, path.String(), "", goInput, result.Result, nil, m, buf) + err = diagLogger.Log(ctx, txn, decisionID, r.RemoteAddr, path.String(), "", goInput, result.Result, nil, m, buf) if err != nil { writer.ErrorAuto(w, err) return @@ -1184,7 +1189,7 @@ func (s *Server) v1DataPost(w http.ResponseWriter, r *http.Request) { rego, err := s.makeRego(ctx, partial, txn, input, path.String(), m, instrument, buf, opts) if err != nil { - _ = diagLogger.Log(ctx, decisionID, r.RemoteAddr, path.String(), "", goInput, nil, err, m, nil) + _ = diagLogger.Log(ctx, txn, decisionID, r.RemoteAddr, path.String(), "", goInput, nil, err, m, nil) writer.ErrorAuto(w, err) return } @@ -1195,7 +1200,7 @@ func (s *Server) v1DataPost(w http.ResponseWriter, r *http.Request) { // Handle results. if err != nil { - _ = diagLogger.Log(ctx, decisionID, r.RemoteAddr, path.String(), "", goInput, nil, err, m, buf) + _ = diagLogger.Log(ctx, txn, decisionID, r.RemoteAddr, path.String(), "", goInput, nil, err, m, buf) writer.ErrorAuto(w, err) return } @@ -1219,7 +1224,7 @@ func (s *Server) v1DataPost(w http.ResponseWriter, r *http.Request) { writer.ErrorAuto(w, err) } } - err = diagLogger.Log(ctx, decisionID, r.RemoteAddr, path.String(), "", goInput, nil, nil, m, buf) + err = diagLogger.Log(ctx, txn, decisionID, r.RemoteAddr, path.String(), "", goInput, nil, nil, m, buf) if err != nil { writer.ErrorAuto(w, err) return @@ -1234,7 +1239,7 @@ func (s *Server) v1DataPost(w http.ResponseWriter, r *http.Request) { result.Explanation = s.getExplainResponse(explainMode, *buf, pretty) } - err = diagLogger.Log(ctx, decisionID, r.RemoteAddr, path.String(), "", goInput, result.Result, nil, m, buf) + err = diagLogger.Log(ctx, txn, decisionID, r.RemoteAddr, path.String(), "", goInput, result.Result, nil, m, buf) if err != nil { writer.ErrorAuto(w, err) return @@ -1618,7 +1623,15 @@ func (s *Server) v1QueryGet(w http.ResponseWriter, r *http.Request) { includeMetrics := getBoolParam(r.URL, types.ParamMetricsV1, true) includeInstrumentation := getBoolParam(r.URL, types.ParamInstrumentV1, true) - results, err := s.execQuery(ctx, r, decisionID, parsedQuery, nil, m, explainMode, includeMetrics, includeInstrumentation, pretty) + txn, err := s.store.NewTransaction(ctx) + if err != nil { + writer.ErrorAuto(w, err) + return + } + + defer s.store.Abort(ctx, txn) + + results, err := s.execQuery(ctx, r, txn, decisionID, parsedQuery, nil, m, explainMode, includeMetrics, includeInstrumentation, pretty) if err != nil { switch err := err.(type) { case ast.Errors: @@ -1672,7 +1685,15 @@ func (s *Server) v1QueryPost(w http.ResponseWriter, r *http.Request) { includeMetrics := getBoolParam(r.URL, types.ParamMetricsV1, true) includeInstrumentation := getBoolParam(r.URL, types.ParamInstrumentV1, true) - results, err := s.execQuery(ctx, r, decisionID, parsedQuery, nil, m, explainMode, includeMetrics, includeInstrumentation, pretty) + txn, err := s.store.NewTransaction(ctx) + if err != nil { + writer.ErrorAuto(w, err) + return + } + + defer s.store.Abort(ctx, txn) + + results, err := s.execQuery(ctx, r, txn, decisionID, parsedQuery, nil, m, explainMode, includeMetrics, includeInstrumentation, pretty) if err != nil { switch err := err.(type) { case ast.Errors: @@ -2393,9 +2414,10 @@ func (l diagnosticsLogger) Instrument() bool { return l.instrument } -func (l diagnosticsLogger) Log(ctx context.Context, decisionID, remoteAddr, path string, query string, input *interface{}, results *interface{}, err error, m metrics.Metrics, tracer *topdown.BufferTracer) error { +func (l diagnosticsLogger) Log(ctx context.Context, txn storage.Transaction, decisionID, remoteAddr, path string, query string, input *interface{}, results *interface{}, err error, m metrics.Metrics, tracer *topdown.BufferTracer) error { info := &Info{ + Txn: txn, Revision: l.revision, Timestamp: time.Now().UTC(), DecisionID: decisionID,