diff --git a/runtime/program_params_validation_test.go b/runtime/program_params_validation_test.go index ba746b9aa..a091d5818 100644 --- a/runtime/program_params_validation_test.go +++ b/runtime/program_params_validation_test.go @@ -1391,4 +1391,95 @@ func TestRuntimeTransactionParameterTypeValidation(t *testing.T) { expectRuntimeError(t, err, &ArgumentNotImportableError{}) }) + newEnumType := func() cadence.Enum { + return cadence.NewEnum([]cadence.Value{ + cadence.NewInt(0), + }).WithType(cadence.NewEnumType( + common.AddressLocation{ + Address: common.MustBytesToAddress([]byte{0x1}), + Name: "C", + }, + "C.Alpha", + cadence.IntType, + []cadence.Field{ + { + Identifier: sema.EnumRawValueFieldName, + Type: cadence.IntType, + }, + }, + nil, + )) + } + + t.Run("Enum Optional Type", func(t *testing.T) { + t.Parallel() + + contracts := map[common.AddressLocation][]byte{ + { + Address: common.MustBytesToAddress([]byte{0x1}), + Name: "C", + }: []byte(` + access(all) contract C { + access(all) + enum Alpha: Int { + access(all) + case A + + access(all) + case B + } + } + `), + } + + script := ` + import C from 0x1 + + transaction(arg: C.Alpha?) {} + ` + + err := executeTransaction(t, script, contracts, cadence.NewOptional(nil)) + assert.NoError(t, err) + }) + + t.Run("Enum Type", func(t *testing.T) { + t.Parallel() + + contracts := map[common.AddressLocation][]byte{ + { + Address: common.MustBytesToAddress([]byte{0x1}), + Name: "C", + }: []byte(` + access(all) contract C { + access(all) + enum Alpha: Int { + access(all) + case A + + access(all) + case B + } + } + `), + } + + script := ` + import C from 0x1 + + transaction(arg: C.Alpha) { + execute { + let values: [AnyStruct] = [] + values.append(arg) + if arg == C.Alpha.A { + values.append(C.Alpha.B) + } + assert(values.length == 2) + } + } + ` + + err := executeTransaction(t, script, contracts, newEnumType()) + assert.NoError(t, err) + }) + } diff --git a/runtime/tests/interpreter/transactions_test.go b/runtime/tests/interpreter/transactions_test.go index a5b518fed..91840748e 100644 --- a/runtime/tests/interpreter/transactions_test.go +++ b/runtime/tests/interpreter/transactions_test.go @@ -320,6 +320,76 @@ func TestInterpretTransactions(t *testing.T) { ArrayElements(inter, values.(*interpreter.ArrayValue)), ) }) + + t.Run("Enum", func(t *testing.T) { + + t.Parallel() + + inter := parseCheckAndInterpret(t, ` + access(all) + enum Alpha: Int { + access(all) + case A + + access(all) + case B + } + + let a = Alpha.A + let b = Alpha.B + + let values: [AnyStruct] = [] + + transaction(x: Alpha) { + + prepare(signer: &Account) { + values.append(signer.address) + values.append(x) + if x == Alpha.A { + values.append(Alpha.B) + } else { + values.append(-1) + } + } + } + `) + + arguments := []interpreter.Value{ + inter.Globals.Get("a").GetValue(inter), + } + + address := common.MustBytesToAddress([]byte{0x1}) + + account := stdlib.NewAccountReferenceValue( + nil, + nil, + interpreter.AddressValue(address), + interpreter.UnauthorizedAccess, + interpreter.EmptyLocationRange, + ) + + prepareArguments := []interpreter.Value{account} + + arguments = append(arguments, prepareArguments...) + + err := inter.InvokeTransaction(0, arguments...) + require.NoError(t, err) + + values := inter.Globals.Get("values").GetValue(inter) + + require.IsType(t, &interpreter.ArrayValue{}, values) + + AssertValueSlicesEqual( + t, + inter, + []interpreter.Value{ + interpreter.AddressValue(address), + inter.Globals.Get("a").GetValue(inter), + inter.Globals.Get("b").GetValue(inter), + }, + ArrayElements(inter, values.(*interpreter.ArrayValue)), + ) + }) } func TestRuntimeInvalidTransferInExecute(t *testing.T) {