diff --git a/runtime/contract_update_validation_test.go b/runtime/contract_update_validation_test.go index 451c4a2def..71709f9b13 100644 --- a/runtime/contract_update_validation_test.go +++ b/runtime/contract_update_validation_test.go @@ -3666,3 +3666,101 @@ func TestTypeRemovalPragmaUpdates(t *testing.T) { }, ) } + +func TestRuntimeContractUpdateErrorsInOldProgram(t *testing.T) { + + t.Parallel() + + testWithValidatorsAndTypeRemovalEnabled(t, + "invalid #removedType pragma in old code", + func(t *testing.T, config Config) { + + const oldCode = ` + access(all) contract Test { + // invalid type removal pragma in old code + #removedType(R, R2) + access(all) resource R {} + } + ` + + const newCode = ` + access(all) contract Test { + access(all) resource R {} + } + ` + + err := testDeployAndUpdate(t, "Test", oldCode, newCode, config) + + // Should not report any errors for the type invalid removal pragma in the old code. + require.NoError(t, err) + }, + ) + + testWithValidators(t, "invalid old program", func(t *testing.T, config Config) { + + runtime := NewTestInterpreterRuntime() + + var events []cadence.Event + + address := common.MustBytesToAddress([]byte{0x2}) + + location := common.AddressLocation{ + Name: "Test", + Address: address, + } + + const oldCode = ` + access(all) fun main() { + // some lines to increase program length + } + ` + + accountCodes := map[Location][]byte{ + location: []byte(oldCode), + } + + runtimeInterface := &TestRuntimeInterface{ + OnGetCode: func(location Location) (bytes []byte, err error) { + return accountCodes[location], nil + }, + Storage: NewTestLedger(nil, nil), + OnGetSigningAccounts: func() ([]Address, error) { + return []Address{address}, nil + }, + OnResolveLocation: NewSingleIdentifierLocationResolver(t), + OnGetAccountContractCode: func(location common.AddressLocation) (code []byte, err error) { + return accountCodes[location], nil + }, + OnUpdateAccountContractCode: func(location common.AddressLocation, code []byte) error { + accountCodes[location] = code + return nil + }, + OnEmitEvent: func(event cadence.Event) error { + events = append(events, event) + return nil + }, + } + + nextTransactionLocation := NewTransactionLocationGenerator() + + const newCode = ` + access(all) contract Test {} + ` + + updateTransaction := []byte(newContractUpdateTransaction("Test", newCode)) + + err := runtime.ExecuteTransaction( + Script{ + Source: updateTransaction, + }, + Context{ + Interface: runtimeInterface, + Location: nextTransactionLocation(), + }, + ) + + RequireError(t, err) + oldProgramError := &stdlib.OldProgramError{} + require.ErrorAs(t, err, &oldProgramError) + }) +} diff --git a/runtime/stdlib/cadence_v0.42_to_v1_contract_upgrade_validator.go b/runtime/stdlib/cadence_v0.42_to_v1_contract_upgrade_validator.go index ee5706967b..c5b5a7a8d1 100644 --- a/runtime/stdlib/cadence_v0.42_to_v1_contract_upgrade_validator.go +++ b/runtime/stdlib/cadence_v0.42_to_v1_contract_upgrade_validator.go @@ -71,6 +71,10 @@ func NewCadenceV042ToV1ContractUpdateValidator( var _ UpdateValidator = &CadenceV042ToV1ContractUpdateValidator{} +func (validator *CadenceV042ToV1ContractUpdateValidator) Location() common.Location { + return validator.underlyingUpdateValidator.location +} + func (validator *CadenceV042ToV1ContractUpdateValidator) isTypeRemovalEnabled() bool { return validator.underlyingUpdateValidator.isTypeRemovalEnabled() } @@ -105,7 +109,7 @@ func (validator *CadenceV042ToV1ContractUpdateValidator) getAccountContractNames func (validator *CadenceV042ToV1ContractUpdateValidator) Validate() error { underlyingValidator := validator.underlyingUpdateValidator - oldRootDecl := getRootDeclaration(validator, underlyingValidator.oldProgram) + oldRootDecl := getRootDeclarationOfOldProgram(validator, underlyingValidator.oldProgram, underlyingValidator.newProgram) if underlyingValidator.hasErrors() { return underlyingValidator.getContractUpdateError() } @@ -314,8 +318,8 @@ func (validator *CadenceV042ToV1ContractUpdateValidator) expectedAuthorizationOf return intersectionType.SupportedEntitlements().Access() } -func (validator *CadenceV042ToV1ContractUpdateValidator) validateEntitlementsRepresentableComposite(decl *ast.CompositeDeclaration) { - dummyNominalType := ast.NewNominalType(nil, decl.Identifier, nil) +func (validator *CadenceV042ToV1ContractUpdateValidator) validateEntitlementsRepresentableComposite(newDecl *ast.CompositeDeclaration) { + dummyNominalType := ast.NewNominalType(nil, newDecl.Identifier, nil) compositeType := validator.getCompositeType(dummyNominalType) supportedEntitlements := compositeType.SupportedEntitlements() @@ -323,13 +327,13 @@ func (validator *CadenceV042ToV1ContractUpdateValidator) validateEntitlementsRep validator.report(&UnrepresentableEntitlementsUpgrade{ Type: compositeType, InvalidAuthorization: supportedEntitlements.Access(), - Range: decl.Range, + Range: newDecl.Range, }) } } -func (validator *CadenceV042ToV1ContractUpdateValidator) validateEntitlementsRepresentableInterface(decl *ast.InterfaceDeclaration) { - dummyNominalType := ast.NewNominalType(nil, decl.Identifier, nil) +func (validator *CadenceV042ToV1ContractUpdateValidator) validateEntitlementsRepresentableInterface(newDecl *ast.InterfaceDeclaration) { + dummyNominalType := ast.NewNominalType(nil, newDecl.Identifier, nil) interfaceType := validator.getInterfaceType(dummyNominalType) supportedEntitlements := interfaceType.SupportedEntitlements() @@ -337,7 +341,7 @@ func (validator *CadenceV042ToV1ContractUpdateValidator) validateEntitlementsRep validator.report(&UnrepresentableEntitlementsUpgrade{ Type: interfaceType, InvalidAuthorization: supportedEntitlements.Access(), - Range: decl.Range, + Range: newDecl.Range, }) } } diff --git a/runtime/stdlib/contract_update_validation.go b/runtime/stdlib/contract_update_validation.go index 92ad1df81e..011b45d274 100644 --- a/runtime/stdlib/contract_update_validation.go +++ b/runtime/stdlib/contract_update_validation.go @@ -28,11 +28,14 @@ import ( "github.com/onflow/cadence/runtime/errors" ) +const typeRemovalPragmaName = "removedType" + type UpdateValidator interface { ast.TypeEqualityChecker Validate() error report(error) + Location() common.Location getCurrentDeclaration() ast.Declaration setCurrentDeclaration(ast.Declaration) @@ -99,6 +102,10 @@ func NewContractUpdateValidator( } } +func (validator *ContractUpdateValidator) Location() common.Location { + return validator.location +} + func (validator *ContractUpdateValidator) isTypeRemovalEnabled() bool { return validator.typeRemovalEnabled } @@ -122,7 +129,7 @@ func (validator *ContractUpdateValidator) getAccountContractNames(address common // Validate validates the contract update, and returns an error if it is an invalid update. func (validator *ContractUpdateValidator) Validate() error { - oldRootDecl := getRootDeclaration(validator, validator.oldProgram) + oldRootDecl := getRootDeclarationOfOldProgram(validator, validator.oldProgram, validator.newProgram) if validator.hasErrors() { return validator.getContractUpdateError() } @@ -210,6 +217,21 @@ func getRootDeclaration(validator UpdateValidator, program *ast.Program) ast.Dec return decl } +func getRootDeclarationOfOldProgram(validator UpdateValidator, program *ast.Program, position ast.HasPosition) ast.Declaration { + decl, err := getRootDeclarationOfProgram(program) + + if err != nil { + validator.report(&OldProgramError{ + Err: &ContractNotFoundError{ + Range: ast.NewUnmeteredRangeFromPositioned(position), + }, + Location: validator.Location(), + }) + } + + return decl +} + func getRootDeclarationOfProgram(program *ast.Program) (ast.Declaration, error) { compositeDecl := program.SoleContractDeclaration() if compositeDecl != nil { @@ -230,7 +252,11 @@ func (validator *ContractUpdateValidator) hasErrors() bool { return len(validator.errors) > 0 } -func collectRemovedTypePragmas(validator UpdateValidator, pragmas []*ast.PragmaDeclaration) *orderedmap.OrderedMap[string, struct{}] { +func collectRemovedTypePragmas( + validator UpdateValidator, + pragmas []*ast.PragmaDeclaration, + reportErrors bool, +) *orderedmap.OrderedMap[string, struct{}] { removedTypes := orderedmap.New[orderedmap.OrderedMap[string, struct{}]](len(pragmas)) for _, pragma := range pragmas { @@ -238,25 +264,33 @@ func collectRemovedTypePragmas(validator UpdateValidator, pragmas []*ast.PragmaD if !isInvocation { continue } + invokedIdentifier, isIdentifier := invocationExpression.InvokedExpression.(*ast.IdentifierExpression) - if !isIdentifier || invokedIdentifier.Identifier.Identifier != "removedType" { + if !isIdentifier || invokedIdentifier.Identifier.Identifier != typeRemovalPragmaName { continue } + if len(invocationExpression.Arguments) != 1 { - validator.report(&InvalidTypeRemovalPragmaError{ - Expression: pragma.Expression, - Range: ast.NewUnmeteredRangeFromPositioned(pragma.Expression), - }) + if reportErrors { + validator.report(&InvalidTypeRemovalPragmaError{ + Expression: pragma.Expression, + Range: ast.NewUnmeteredRangeFromPositioned(pragma.Expression), + }) + } continue } - removedTypeName, isIdentifer := invocationExpression.Arguments[0].Expression.(*ast.IdentifierExpression) - if !isIdentifer { - validator.report(&InvalidTypeRemovalPragmaError{ - Expression: pragma.Expression, - Range: ast.NewUnmeteredRangeFromPositioned(pragma.Expression), - }) + + removedTypeName, isIdentifier := invocationExpression.Arguments[0].Expression.(*ast.IdentifierExpression) + if !isIdentifier { + if reportErrors { + validator.report(&InvalidTypeRemovalPragmaError{ + Expression: pragma.Expression, + Range: ast.NewUnmeteredRangeFromPositioned(pragma.Expression), + }) + } continue } + removedTypes.Set(removedTypeName.Identifier.Identifier, struct{}{}) } @@ -282,8 +316,8 @@ func checkDeclarationUpdatability( oldIdentifier := oldDeclaration.DeclarationIdentifier() newIdentifier := newDeclaration.DeclarationIdentifier() - if oldIdentifier.Identifier != newIdentifier.Identifier { + if oldIdentifier.Identifier != newIdentifier.Identifier { validator.report(&NameMismatchError{ OldName: oldIdentifier.Identifier, NewName: newIdentifier.Identifier, @@ -431,8 +465,20 @@ func checkNestedDeclarations( var removedTypes *orderedmap.OrderedMap[string, struct{}] if validator.isTypeRemovalEnabled() { // process pragmas first, as they determine whether types can later be removed - oldRemovedTypes := collectRemovedTypePragmas(validator, oldDeclaration.DeclarationMembers().Pragmas()) - removedTypes = collectRemovedTypePragmas(validator, newDeclaration.DeclarationMembers().Pragmas()) + oldRemovedTypes := collectRemovedTypePragmas( + validator, + oldDeclaration.DeclarationMembers().Pragmas(), + // Do not report errors for pragmas in the old code. + // We are only interested in collecting the pragmas in old code. + // This also avoid reporting mixed errors from both old and new codes. + false, + ) + + removedTypes = collectRemovedTypePragmas( + validator, + newDeclaration.DeclarationMembers().Pragmas(), + true, + ) // #typeRemoval pragmas cannot be removed, so any that appear in the old program must appear in the new program // they can however, be added, so use the new program's type removals for the purposes of checking the upgrade