From 0eed0c09969ed969324b4ac3fce3c47ae784229f Mon Sep 17 00:00:00 2001 From: Jordan Krage Date: Sat, 10 Feb 2024 07:42:39 -0600 Subject: [PATCH] ocr3: add context --- .../internal/managed/managed_mercury_oracle.go | 2 +- .../internal/managed/managed_ocr3_oracle.go | 4 ++-- .../internal/mercuryshim/mercuryshims.go | 14 +++++++------- .../internal/ocr3/protocol/outcome_generation.go | 2 +- .../ocr3/protocol/outcome_generation_follower.go | 5 +++-- .../ocr3/protocol/outcome_generation_leader.go | 1 + .../internal/ocr3/protocol/report_attestation.go | 3 ++- .../internal/shim/ocr3_reporting_plugin.go | 16 ++++++++-------- .../ocr3types/counting_mercury_plugin.go | 2 +- .../ocr3types/mercury_plugin.go | 4 ++-- offchainreporting2plus/ocr3types/plugin.go | 10 +++++----- offchainreporting2plus/ocr3types/types.go | 2 +- 12 files changed, 34 insertions(+), 31 deletions(-) diff --git a/offchainreporting2plus/internal/managed/managed_mercury_oracle.go b/offchainreporting2plus/internal/managed/managed_mercury_oracle.go index 9ea5e6e..22fd31e 100644 --- a/offchainreporting2plus/internal/managed/managed_mercury_oracle.go +++ b/offchainreporting2plus/internal/managed/managed_mercury_oracle.go @@ -98,7 +98,7 @@ func RunManagedMercuryOracle( "oid": oid, }) - mercuryPlugin, mercuryPluginInfo, err := mercuryPluginFactory.NewMercuryPlugin(ocr3types.MercuryPluginConfig{ + mercuryPlugin, mercuryPluginInfo, err := mercuryPluginFactory.NewMercuryPlugin(ctx, ocr3types.MercuryPluginConfig{ sharedConfig.ConfigDigest, oid, sharedConfig.N(), diff --git a/offchainreporting2plus/internal/managed/managed_ocr3_oracle.go b/offchainreporting2plus/internal/managed/managed_ocr3_oracle.go index 8e7a98c..d08bd65 100644 --- a/offchainreporting2plus/internal/managed/managed_ocr3_oracle.go +++ b/offchainreporting2plus/internal/managed/managed_ocr3_oracle.go @@ -62,7 +62,7 @@ func RunManagedOCR3Oracle[RI any]( func(ctx context.Context, contractConfig types.ContractConfig, logger loghelper.LoggerWithContext) { skipResourceExhaustionChecks := localConfig.DevelopmentMode == types.EnableDangerousDevelopmentMode - fromAccount, err := contractTransmitter.FromAccount() + fromAccount, err := contractTransmitter.FromAccount(ctx) if err != nil { logger.Error("ManagedOCR3Oracle: error getting FromAccount", commontypes.LogFields{ "error": err, @@ -105,7 +105,7 @@ func RunManagedOCR3Oracle[RI any]( "oid": oid, }) - reportingPlugin, reportingPluginInfo, err := reportingPluginFactory.NewReportingPlugin(ocr3types.ReportingPluginConfig{ + reportingPlugin, reportingPluginInfo, err := reportingPluginFactory.NewReportingPlugin(ctx, ocr3types.ReportingPluginConfig{ sharedConfig.ConfigDigest, oid, sharedConfig.N(), diff --git a/offchainreporting2plus/internal/mercuryshim/mercuryshims.go b/offchainreporting2plus/internal/mercuryshim/mercuryshims.go index 566e718..43547ad 100644 --- a/offchainreporting2plus/internal/mercuryshim/mercuryshims.go +++ b/offchainreporting2plus/internal/mercuryshim/mercuryshims.go @@ -94,8 +94,8 @@ func (t *MercuryOCR3ContractTransmitter) Transmit( ) } -func (t *MercuryOCR3ContractTransmitter) FromAccount() (types.Account, error) { - return t.ocr2ContractTransmitter.FromAccount(context.Background()) +func (t *MercuryOCR3ContractTransmitter) FromAccount(ctx context.Context) (types.Account, error) { + return t.ocr2ContractTransmitter.FromAccount(ctx) } func ocr3MaxOutcomeLength(maxReportLength int) int { @@ -171,22 +171,22 @@ func (p *MercuryReportingPlugin) Observation(ctx context.Context, outctx ocr3typ return observation, nil } -func (p *MercuryReportingPlugin) ValidateObservation(outctx ocr3types.OutcomeContext, query types.Query, ao types.AttributedObservation) error { +func (p *MercuryReportingPlugin) ValidateObservation(ctx context.Context, outctx ocr3types.OutcomeContext, query types.Query, ao types.AttributedObservation) error { return nil } -func (p *MercuryReportingPlugin) ObservationQuorum(outctx ocr3types.OutcomeContext, query types.Query) (ocr3types.Quorum, error) { +func (p *MercuryReportingPlugin) ObservationQuorum(ctx context.Context, outctx ocr3types.OutcomeContext, query types.Query) (ocr3types.Quorum, error) { return ocr3types.QuorumTwoFPlusOne, nil } -func (p *MercuryReportingPlugin) Outcome(outctx ocr3types.OutcomeContext, query types.Query, aos []types.AttributedObservation) (ocr3types.Outcome, error) { +func (p *MercuryReportingPlugin) Outcome(ctx context.Context, outctx ocr3types.OutcomeContext, query types.Query, aos []types.AttributedObservation) (ocr3types.Outcome, error) { previousOutcomeDeserialized, err := deserializeMercuryReportingPluginOutcome(outctx.PreviousOutcome) if err != nil { return nil, err } //nolint:staticcheck - shouldReport, report, err := p.Plugin.Report(types.ReportTimestamp{p.Config.ConfigDigest, uint32(outctx.Epoch), uint8(outctx.Round)}, previousOutcomeDeserialized.Report, aos) + shouldReport, report, err := p.Plugin.Report(ctx, types.ReportTimestamp{p.Config.ConfigDigest, uint32(outctx.Epoch), uint8(outctx.Round)}, previousOutcomeDeserialized.Report, aos) if err != nil { return nil, err } @@ -204,7 +204,7 @@ func (p *MercuryReportingPlugin) Outcome(outctx ocr3types.OutcomeContext, query return serializeMercuryReportingPluginOutcome(outcomeDeserialized), nil } -func (p *MercuryReportingPlugin) Reports(seqNr uint64, outcome ocr3types.Outcome) ([]ocr3types.ReportWithInfo[MercuryReportInfo], error) { +func (p *MercuryReportingPlugin) Reports(ctx context.Context, seqNr uint64, outcome ocr3types.Outcome) ([]ocr3types.ReportWithInfo[MercuryReportInfo], error) { outcomeDeserialized, err := deserializeMercuryReportingPluginOutcome(outcome) if err != nil { return nil, err diff --git a/offchainreporting2plus/internal/ocr3/protocol/outcome_generation.go b/offchainreporting2plus/internal/ocr3/protocol/outcome_generation.go index 34af222..abc8985 100644 --- a/offchainreporting2plus/internal/ocr3/protocol/outcome_generation.go +++ b/offchainreporting2plus/internal/ocr3/protocol/outcome_generation.go @@ -364,7 +364,7 @@ func (outgen *outcomeGenerationState[RI]) ObservationQuorum(query types.Query) ( 0, // pure function outgen.OutcomeCtx(outgen.sharedState.seqNr), func(ctx context.Context, outctx ocr3types.OutcomeContext) (ocr3types.Quorum, error) { - return outgen.reportingPlugin.ObservationQuorum(outctx, query) + return outgen.reportingPlugin.ObservationQuorum(ctx, outctx, query) }, ) diff --git a/offchainreporting2plus/internal/ocr3/protocol/outcome_generation_follower.go b/offchainreporting2plus/internal/ocr3/protocol/outcome_generation_follower.go index 9c179c0..3b7e5bb 100644 --- a/offchainreporting2plus/internal/ocr3/protocol/outcome_generation_follower.go +++ b/offchainreporting2plus/internal/ocr3/protocol/outcome_generation_follower.go @@ -367,6 +367,7 @@ func (outgen *outcomeGenerationState[RI]) tryProcessProposalPool() { outgen.OutcomeCtx(outgen.sharedState.seqNr), func(ctx context.Context, outctx ocr3types.OutcomeContext) (error, error) { return outgen.reportingPlugin.ValidateObservation( + ctx, outctx, *outgen.followerState.query, types.AttributedObservation{aso.SignedObservation.Observation, aso.Observer}, @@ -404,8 +405,8 @@ func (outgen *outcomeGenerationState[RI]) tryProcessProposalPool() { "Outcome", 0, // Outcome is a pure function and should finish "instantly" outgen.OutcomeCtx(outgen.sharedState.seqNr), - func(_ context.Context, outctx ocr3types.OutcomeContext) (ocr3types.Outcome, error) { - return outgen.reportingPlugin.Outcome(outctx, *outgen.followerState.query, attributedObservations) + func(ctx context.Context, outctx ocr3types.OutcomeContext) (ocr3types.Outcome, error) { + return outgen.reportingPlugin.Outcome(ctx, outctx, *outgen.followerState.query, attributedObservations) }, ) if !ok { diff --git a/offchainreporting2plus/internal/ocr3/protocol/outcome_generation_leader.go b/offchainreporting2plus/internal/ocr3/protocol/outcome_generation_leader.go index 2967abc..750e417 100644 --- a/offchainreporting2plus/internal/ocr3/protocol/outcome_generation_leader.go +++ b/offchainreporting2plus/internal/ocr3/protocol/outcome_generation_leader.go @@ -288,6 +288,7 @@ func (outgen *outcomeGenerationState[RI]) messageObservation(msg MessageObservat outgen.OutcomeCtx(outgen.sharedState.seqNr), func(ctx context.Context, outctx ocr3types.OutcomeContext) (error, error) { return outgen.reportingPlugin.ValidateObservation( + ctx, outctx, outgen.leaderState.query, types.AttributedObservation{msg.SignedObservation.Observation, sender}, diff --git a/offchainreporting2plus/internal/ocr3/protocol/report_attestation.go b/offchainreporting2plus/internal/ocr3/protocol/report_attestation.go index 56974f7..15a7c16 100644 --- a/offchainreporting2plus/internal/ocr3/protocol/report_attestation.go +++ b/offchainreporting2plus/internal/ocr3/protocol/report_attestation.go @@ -451,8 +451,9 @@ func (repatt *reportAttestationState[RI]) receivedCertifiedCommit(certifiedCommi commontypes.LogFields{"seqNr": certifiedCommit.SeqNr}, "Reports", 0, // Reports is a pure function and should finish "instantly" - func(context.Context) ([]ocr3types.ReportWithInfo[RI], error) { + func(ctx context.Context) ([]ocr3types.ReportWithInfo[RI], error) { return repatt.reportingPlugin.Reports( + ctx, certifiedCommit.SeqNr, certifiedCommit.Outcome, ) diff --git a/offchainreporting2plus/internal/shim/ocr3_reporting_plugin.go b/offchainreporting2plus/internal/shim/ocr3_reporting_plugin.go index ab02595..6ccf245 100644 --- a/offchainreporting2plus/internal/shim/ocr3_reporting_plugin.go +++ b/offchainreporting2plus/internal/shim/ocr3_reporting_plugin.go @@ -31,8 +31,8 @@ func (rp LimitCheckOCR3ReportingPlugin[RI]) Query(ctx context.Context, outctx oc return query, nil } -func (rp LimitCheckOCR3ReportingPlugin[RI]) ObservationQuorum(outctx ocr3types.OutcomeContext, query types.Query) (ocr3types.Quorum, error) { - return rp.Plugin.ObservationQuorum(outctx, query) +func (rp LimitCheckOCR3ReportingPlugin[RI]) ObservationQuorum(ctx context.Context, outctx ocr3types.OutcomeContext, query types.Query) (ocr3types.Quorum, error) { + return rp.Plugin.ObservationQuorum(ctx, outctx, query) } func (rp LimitCheckOCR3ReportingPlugin[RI]) Observation(ctx context.Context, outctx ocr3types.OutcomeContext, query types.Query) (types.Observation, error) { @@ -46,12 +46,12 @@ func (rp LimitCheckOCR3ReportingPlugin[RI]) Observation(ctx context.Context, out return observation, nil } -func (rp LimitCheckOCR3ReportingPlugin[RI]) ValidateObservation(outctx ocr3types.OutcomeContext, query types.Query, ao types.AttributedObservation) error { - return rp.Plugin.ValidateObservation(outctx, query, ao) +func (rp LimitCheckOCR3ReportingPlugin[RI]) ValidateObservation(ctx context.Context, outctx ocr3types.OutcomeContext, query types.Query, ao types.AttributedObservation) error { + return rp.Plugin.ValidateObservation(ctx, outctx, query, ao) } -func (rp LimitCheckOCR3ReportingPlugin[RI]) Outcome(outctx ocr3types.OutcomeContext, query types.Query, aos []types.AttributedObservation) (ocr3types.Outcome, error) { - outcome, err := rp.Plugin.Outcome(outctx, query, aos) +func (rp LimitCheckOCR3ReportingPlugin[RI]) Outcome(ctx context.Context, outctx ocr3types.OutcomeContext, query types.Query, aos []types.AttributedObservation) (ocr3types.Outcome, error) { + outcome, err := rp.Plugin.Outcome(ctx, outctx, query, aos) if err != nil { return nil, err } @@ -61,8 +61,8 @@ func (rp LimitCheckOCR3ReportingPlugin[RI]) Outcome(outctx ocr3types.OutcomeCont return outcome, nil } -func (rp LimitCheckOCR3ReportingPlugin[RI]) Reports(seqNr uint64, outcome ocr3types.Outcome) ([]ocr3types.ReportWithInfo[RI], error) { - reports, err := rp.Plugin.Reports(seqNr, outcome) +func (rp LimitCheckOCR3ReportingPlugin[RI]) Reports(ctx context.Context, seqNr uint64, outcome ocr3types.Outcome) ([]ocr3types.ReportWithInfo[RI], error) { + reports, err := rp.Plugin.Reports(ctx, seqNr, outcome) if err != nil { return nil, err } diff --git a/offchainreporting2plus/ocr3types/counting_mercury_plugin.go b/offchainreporting2plus/ocr3types/counting_mercury_plugin.go index 198e036..648f044 100644 --- a/offchainreporting2plus/ocr3types/counting_mercury_plugin.go +++ b/offchainreporting2plus/ocr3types/counting_mercury_plugin.go @@ -18,7 +18,7 @@ func (p *CountingMercuryPlugin) Observation(ctx context.Context, repts types.Rep return []byte{byte(rand.Int() % 2)}, nil } -func (p *CountingMercuryPlugin) Report(repts types.ReportTimestamp, previousReport types.Report, aos []types.AttributedObservation) (bool, types.Report, error) { +func (p *CountingMercuryPlugin) Report(ctx context.Context, repts types.ReportTimestamp, previousReport types.Report, aos []types.AttributedObservation) (bool, types.Report, error) { report := make([]byte, 4) if len(previousReport) == 0 { if p.initializedReport { diff --git a/offchainreporting2plus/ocr3types/mercury_plugin.go b/offchainreporting2plus/ocr3types/mercury_plugin.go index b786133..199fc98 100644 --- a/offchainreporting2plus/ocr3types/mercury_plugin.go +++ b/offchainreporting2plus/ocr3types/mercury_plugin.go @@ -12,7 +12,7 @@ type MercuryPluginFactory interface { // Creates a new mercury-specific reporting plugin instance. The instance may have // associated goroutines or hold system resources, which should be // released when its Close() function is called. - NewMercuryPlugin(MercuryPluginConfig) (MercuryPlugin, MercuryPluginInfo, error) + NewMercuryPlugin(context.Context, MercuryPluginConfig) (MercuryPlugin, MercuryPluginInfo, error) } type MercuryPluginConfig struct { @@ -127,7 +127,7 @@ type MercuryPlugin interface { // You may assume that the sequence of epochs and the sequence of rounds // within an epoch are monotonically increasing during the lifetime // of an instance of this interface. - Report(repts types.ReportTimestamp, previousReport types.Report, aos []types.AttributedObservation) (bool, types.Report, error) + Report(ctx context.Context, repts types.ReportTimestamp, previousReport types.Report, aos []types.AttributedObservation) (bool, types.Report, error) // If Close is called a second time, it may return an error but must not // panic. This will always be called when a ReportingPlugin is no longer diff --git a/offchainreporting2plus/ocr3types/plugin.go b/offchainreporting2plus/ocr3types/plugin.go index 98ce18b..6cd6f4c 100644 --- a/offchainreporting2plus/ocr3types/plugin.go +++ b/offchainreporting2plus/ocr3types/plugin.go @@ -14,7 +14,7 @@ type ReportingPluginFactory[RI any] interface { // Creates a new reporting plugin instance. The instance may have // associated goroutines or hold system resources, which should be // released when its Close() function is called. - NewReportingPlugin(ReportingPluginConfig) (ReportingPlugin[RI], ReportingPluginInfo, error) + NewReportingPlugin(context.Context, ReportingPluginConfig) (ReportingPlugin[RI], ReportingPluginInfo, error) } type ReportingPluginConfig struct { @@ -197,7 +197,7 @@ type ReportingPlugin[RI any] interface { // *not* strictly) across the lifetime of a protocol instance and that // outctx.previousOutcome contains the consensus outcome with sequence // number (outctx.SeqNr-1). - ValidateObservation(outctx OutcomeContext, query types.Query, ao types.AttributedObservation) error + ValidateObservation(ctx context.Context, outctx OutcomeContext, query types.Query, ao types.AttributedObservation) error // ObservationQuorum returns the minimum number of valid (according to // ValidateObservation) observations needed to construct an outcome. @@ -207,7 +207,7 @@ type ReportingPlugin[RI any] interface { // This is an advanced feature. The "default" approach (what OCR1 & OCR2 // did) is to have an empty ValidateObservation function and return // QuorumTwoFPlusOne from this function. - ObservationQuorum(outctx OutcomeContext, query types.Query) (Quorum, error) + ObservationQuorum(ctx context.Context, outctx OutcomeContext, query types.Query) (Quorum, error) // Generates an outcome for a seqNr, typically based on the previous // outcome, the current query, and the current set of attributed @@ -222,7 +222,7 @@ type ReportingPlugin[RI any] interface { // // You may assume that all provided observations have been validated by // ValidateObservation. - Outcome(outctx OutcomeContext, query types.Query, aos []types.AttributedObservation) (Outcome, error) + Outcome(ctx context.Context, outctx OutcomeContext, query types.Query, aos []types.AttributedObservation) (Outcome, error) // Generates a (possibly empty) list of reports from an outcome. Each report // will be signed and possibly be transmitted to the contract. (Depending on @@ -237,7 +237,7 @@ type ReportingPlugin[RI any] interface { // *not* strictly) across the lifetime of a protocol instance and that // outctx.previousOutcome contains the consensus outcome with sequence // number (outctx.SeqNr-1). - Reports(seqNr uint64, outcome Outcome) ([]ReportWithInfo[RI], error) + Reports(ctx context.Context, seqNr uint64, outcome Outcome) ([]ReportWithInfo[RI], error) // Decides whether a report should be accepted for transmission. Any report // passed to this function will have been attested, i.e. signed by f+1 diff --git a/offchainreporting2plus/ocr3types/types.go b/offchainreporting2plus/ocr3types/types.go index cd1b054..cf85101 100644 --- a/offchainreporting2plus/ocr3types/types.go +++ b/offchainreporting2plus/ocr3types/types.go @@ -27,7 +27,7 @@ type ContractTransmitter[RI any] interface { ) error // Account from which the transmitter invokes the contract - FromAccount() (types.Account, error) + FromAccount(ctx context.Context) (types.Account, error) } // OnchainKeyring provides cryptographic signatures that need to be verifiable