diff --git a/migrations/capcons/capabilities.go b/migrations/capcons/capabilities.go index f33f1e9bd..8b06c9379 100644 --- a/migrations/capcons/capabilities.go +++ b/migrations/capcons/capabilities.go @@ -19,9 +19,13 @@ package capcons import ( + "cmp" "fmt" + "strings" "sync" + "golang.org/x/exp/slices" + "github.com/onflow/cadence/runtime/common" "github.com/onflow/cadence/runtime/interpreter" ) @@ -38,7 +42,8 @@ type Path struct { } type AccountCapabilities struct { - Capabilities []AccountCapability + capabilities []AccountCapability + sorted bool } func (c *AccountCapabilities) Record( @@ -47,8 +52,8 @@ func (c *AccountCapabilities) Record( storageKey interpreter.StorageKey, storageMapKey interpreter.StorageMapKey, ) { - c.Capabilities = append( - c.Capabilities, + c.capabilities = append( + c.capabilities, AccountCapability{ TargetPath: path, BorrowType: borrowType, @@ -58,6 +63,43 @@ func (c *AccountCapabilities) Record( }, }, ) + + // Reset the sorted flag, if new entries are added. + c.sorted = false +} + +// ForEachSorted will first sort the capabilities list, +// and iterates through the sorted list. +func (c *AccountCapabilities) ForEachSorted( + f func(AccountCapability) bool, +) { + c.sort() + for _, accountCapability := range c.capabilities { + if !f(accountCapability) { + return + } + } +} + +func (c *AccountCapabilities) sort() { + if c.sorted { + return + } + + slices.SortFunc( + c.capabilities, + func(a, b AccountCapability) int { + pathA := a.TargetPath + pathB := b.TargetPath + + return cmp.Or( + cmp.Compare(pathA.Domain, pathB.Domain), + strings.Compare(pathA.Identifier, pathB.Identifier), + ) + }, + ) + + c.sorted = true } type AccountsCapabilities struct { @@ -97,11 +139,8 @@ func (m *AccountsCapabilities) ForEach( } accountCapabilities := rawAccountCapabilities.(*AccountCapabilities) - for _, accountCapability := range accountCapabilities.Capabilities { - if !f(accountCapability) { - return - } - } + + accountCapabilities.ForEachSorted(f) } func (m *AccountsCapabilities) Get(address common.Address) *AccountCapabilities { diff --git a/migrations/capcons/capabilities_test.go b/migrations/capcons/capabilities_test.go new file mode 100644 index 000000000..b4b486966 --- /dev/null +++ b/migrations/capcons/capabilities_test.go @@ -0,0 +1,122 @@ +/* + * Cadence - The resource-oriented smart contract programming language + * + * Copyright Flow Foundation + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package capcons + +import ( + "testing" + + "github.com/stretchr/testify/assert" + "github.com/stretchr/testify/require" + + "github.com/onflow/cadence/runtime/common" + "github.com/onflow/cadence/runtime/interpreter" +) + +func TestCapabilitiesIteration(t *testing.T) { + t.Parallel() + + caps := AccountCapabilities{} + + caps.Record( + interpreter.NewUnmeteredPathValue(common.PathDomainPublic, "b"), + nil, + interpreter.StorageKey{}, + nil, + ) + + caps.Record( + interpreter.NewUnmeteredPathValue(common.PathDomainPublic, "a"), + nil, + interpreter.StorageKey{}, + nil, + ) + + caps.Record( + interpreter.NewUnmeteredPathValue(common.PathDomainStorage, "c"), + nil, + interpreter.StorageKey{}, + nil, + ) + + caps.Record( + interpreter.NewUnmeteredPathValue(common.PathDomainStorage, "a"), + nil, + interpreter.StorageKey{}, + nil, + ) + + caps.Record( + interpreter.NewUnmeteredPathValue(common.PathDomainStorage, "b"), + nil, + interpreter.StorageKey{}, + nil, + ) + + require.False(t, caps.sorted) + + var paths []interpreter.PathValue + caps.ForEachSorted(func(capability AccountCapability) bool { + paths = append(paths, capability.TargetPath) + return true + }) + + require.True(t, caps.sorted) + + assert.Equal( + t, + []interpreter.PathValue{ + interpreter.NewUnmeteredPathValue(common.PathDomainStorage, "a"), + interpreter.NewUnmeteredPathValue(common.PathDomainStorage, "b"), + interpreter.NewUnmeteredPathValue(common.PathDomainStorage, "c"), + interpreter.NewUnmeteredPathValue(common.PathDomainPublic, "a"), + interpreter.NewUnmeteredPathValue(common.PathDomainPublic, "b"), + }, + paths, + ) + + caps.Record( + interpreter.NewUnmeteredPathValue(common.PathDomainStorage, "aa"), + nil, + interpreter.StorageKey{}, + nil, + ) + + require.False(t, caps.sorted) + + paths = make([]interpreter.PathValue, 0) + caps.ForEachSorted(func(capability AccountCapability) bool { + paths = append(paths, capability.TargetPath) + return true + }) + + require.True(t, caps.sorted) + + assert.Equal( + t, + []interpreter.PathValue{ + interpreter.NewUnmeteredPathValue(common.PathDomainStorage, "a"), + interpreter.NewUnmeteredPathValue(common.PathDomainStorage, "aa"), + interpreter.NewUnmeteredPathValue(common.PathDomainStorage, "b"), + interpreter.NewUnmeteredPathValue(common.PathDomainStorage, "c"), + interpreter.NewUnmeteredPathValue(common.PathDomainPublic, "a"), + interpreter.NewUnmeteredPathValue(common.PathDomainPublic, "b"), + }, + paths, + ) +} diff --git a/migrations/capcons/storagecapmigration.go b/migrations/capcons/storagecapmigration.go index ba95c1dd1..4131f7c1e 100644 --- a/migrations/capcons/storagecapmigration.go +++ b/migrations/capcons/storagecapmigration.go @@ -109,7 +109,7 @@ func IssueAccountCapabilities( false, ) - for _, capability := range capabilities.Capabilities { + capabilities.ForEachSorted(func(capability AccountCapability) bool { addressPath := interpreter.AddressPath{ Address: address, @@ -123,7 +123,7 @@ func IssueAccountCapabilities( if hasBorrowType { if _, ok := typedCapabilityMapping.Get(addressPath, capabilityBorrowType.ID()); ok { - continue + return true } borrowType = capabilityBorrowType.(*interpreter.ReferenceStaticType) @@ -141,7 +141,7 @@ func IssueAccountCapabilities( reporter.MissingBorrowType(addressPath, targetPath) if _, _, ok := untypedCapabilityMapping.Get(addressPath); ok { - continue + return true } // If the borrow type is missing, then borrow it as the type of the value. @@ -151,7 +151,7 @@ func IssueAccountCapabilities( // However, if there is no value at the target, //it is not possible to migrate this cap. if value == nil { - continue + return true } valueType := value.StaticType(inter) @@ -198,5 +198,7 @@ func IssueAccountCapabilities( borrowType, capabilityID, ) - } + + return true + }) }