Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Updated getSeedBytes to handle array seeds #1007

Merged
merged 1 commit into from
Jan 14, 2025
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
83 changes: 79 additions & 4 deletions integration-tests/relayinterface/lookups_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -127,6 +127,7 @@ func TestAccountLookups(t *testing.T) {

func TestPDALookups(t *testing.T) {
programID := chainwriter.GetRandomPubKey(t)
ctx := tests.Context(t)

t.Run("PDALookup resolves valid PDA with constant address seeds", func(t *testing.T) {
seed := chainwriter.GetRandomPubKey(t)
Expand All @@ -152,7 +153,6 @@ func TestPDALookups(t *testing.T) {
IsWritable: true,
}

ctx := tests.Context(t)
result, err := pdaLookup.Resolve(ctx, nil, nil, nil)
require.NoError(t, err)
require.Equal(t, expectedMeta, result)
Expand Down Expand Up @@ -183,7 +183,6 @@ func TestPDALookups(t *testing.T) {
IsWritable: true,
}

ctx := tests.Context(t)
args := map[string]interface{}{
"test_seed": seed1,
"another_seed": seed2,
Expand All @@ -205,7 +204,6 @@ func TestPDALookups(t *testing.T) {
IsWritable: true,
}

ctx := tests.Context(t)
args := map[string]interface{}{
"test_seed": []byte("data"),
}
Expand Down Expand Up @@ -241,7 +239,6 @@ func TestPDALookups(t *testing.T) {
IsWritable: true,
}

ctx := tests.Context(t)
args := map[string]interface{}{
"test_seed": seed1,
"another_seed": seed2,
Expand All @@ -251,6 +248,84 @@ func TestPDALookups(t *testing.T) {
require.NoError(t, err)
require.Equal(t, expectedMeta, result)
})

t.Run("PDALookups resolves list of PDAs when a seed is an array", func(t *testing.T) {
singleSeed := []byte("test_seed")
arraySeed := []solana.PublicKey{chainwriter.GetRandomPubKey(t), chainwriter.GetRandomPubKey(t)}

expectedMeta := []*solana.AccountMeta{}

for _, seed := range arraySeed {
pda, _, err := solana.FindProgramAddress([][]byte{singleSeed, seed.Bytes()}, programID)
require.NoError(t, err)
meta := &solana.AccountMeta{
PublicKey: pda,
IsSigner: false,
IsWritable: false,
}
expectedMeta = append(expectedMeta, meta)
}

pdaLookup := chainwriter.PDALookups{
Name: "TestPDA",
PublicKey: chainwriter.AccountConstant{Name: "ProgramID", Address: programID.String()},
Seeds: []chainwriter.Seed{
{Dynamic: chainwriter.AccountLookup{Name: "seed1", Location: "single_seed"}},
{Dynamic: chainwriter.AccountLookup{Name: "seed2", Location: "array_seed"}},
},
IsSigner: false,
IsWritable: false,
}

args := map[string]interface{}{
"single_seed": singleSeed,
"array_seed": arraySeed,
}

result, err := pdaLookup.Resolve(ctx, args, nil, nil)
require.NoError(t, err)
require.Equal(t, expectedMeta, result)
})

t.Run("PDALookups resolves list of PDAs when multiple seeds are arrays", func(t *testing.T) {
arraySeed1 := [][]byte{[]byte("test_seed1"), []byte("test_seed2")}
arraySeed2 := []solana.PublicKey{chainwriter.GetRandomPubKey(t), chainwriter.GetRandomPubKey(t)}

expectedMeta := []*solana.AccountMeta{}

for _, seed1 := range arraySeed1 {
for _, seed2 := range arraySeed2 {
pda, _, err := solana.FindProgramAddress([][]byte{seed1, seed2.Bytes()}, programID)
require.NoError(t, err)
meta := &solana.AccountMeta{
PublicKey: pda,
IsSigner: false,
IsWritable: false,
}
expectedMeta = append(expectedMeta, meta)
}
}

pdaLookup := chainwriter.PDALookups{
Name: "TestPDA",
PublicKey: chainwriter.AccountConstant{Name: "ProgramID", Address: programID.String()},
Seeds: []chainwriter.Seed{
{Dynamic: chainwriter.AccountLookup{Name: "seed1", Location: "seed1"}},
{Dynamic: chainwriter.AccountLookup{Name: "seed2", Location: "seed2"}},
},
IsSigner: false,
IsWritable: false,
}

args := map[string]interface{}{
"seed1": arraySeed1,
"seed2": arraySeed2,
}

result, err := pdaLookup.Resolve(ctx, args, nil, nil)
require.NoError(t, err)
require.Equal(t, expectedMeta, result)
})
}

func TestLookupTables(t *testing.T) {
Expand Down
19 changes: 16 additions & 3 deletions pkg/solana/chainwriter/helpers.go
Original file line number Diff line number Diff line change
Expand Up @@ -84,6 +84,22 @@ func errorWithDebugID(err error, debugID string) error {
// traversePath recursively traverses the given structure based on the provided path.
func traversePath(data any, path []string) ([]any, error) {
if len(path) == 0 {
val := reflect.ValueOf(data)

// If the final data is a slice or array, flatten it into multiple items,
if val.Kind() == reflect.Slice || val.Kind() == reflect.Array {
// don't flatten []byte
prashantkumar1982 marked this conversation as resolved.
Show resolved Hide resolved
if val.Type().Elem().Kind() == reflect.Uint8 {
return []any{val.Interface()}, nil
}

var results []any
for i := 0; i < val.Len(); i++ {
results = append(results, val.Index(i).Interface())
}
return results, nil
}
// Otherwise, return just this one item
return []any{data}, nil
}

Expand Down Expand Up @@ -124,9 +140,6 @@ func traversePath(data any, path []string) ([]any, error) {
}
return traversePath(value.Interface(), path[1:])
default:
if len(path) == 1 && val.Kind() == reflect.Slice && val.Type().Elem().Kind() == reflect.Uint8 {
return []any{val.Interface()}, nil
}
return nil, errors.New("unexpected type encountered at path: " + path[0])
}
}
Expand Down
100 changes: 70 additions & 30 deletions pkg/solana/chainwriter/lookups.go
Original file line number Diff line number Diff line change
Expand Up @@ -142,7 +142,7 @@ func (pda PDALookups) Resolve(ctx context.Context, args any, derivedTableMap map
return nil, fmt.Errorf("error getting public key for PDALookups: %w", err)
}

seeds, err := getSeedBytes(ctx, pda, args, derivedTableMap, reader)
seeds, err := getSeedBytesCombinations(ctx, pda, args, derivedTableMap, reader)
if err != nil {
return nil, fmt.Errorf("error getting seeds for PDALookups: %w", err)
}
Expand Down Expand Up @@ -209,64 +209,104 @@ func decodeBorshIntoType(data []byte, typ reflect.Type) (interface{}, error) {
return reflect.ValueOf(instance).Elem().Interface(), nil
}

// getSeedBytes extracts the seeds for the PDALookups.
// It handles both AddressSeeds (which are public keys) and ValueSeeds (which are byte arrays from input args).
func getSeedBytes(ctx context.Context, lookup PDALookups, args any, derivedTableMap map[string]map[string][]*solana.AccountMeta, reader client.Reader) ([][]byte, error) {
var seedBytes [][]byte
// getSeedBytesCombinations extracts the seeds for the PDALookups.
// The return type is [][][]byte, where each element of the outer slice is
// one combination of seeds. This handles the case where one seed can resolve
// to multiple addresses, multiplying the combinations accordingly.
func getSeedBytesCombinations(
ctx context.Context,
lookup PDALookups,
args any,
derivedTableMap map[string]map[string][]*solana.AccountMeta,
reader client.Reader,
) ([][][]byte, error) {
allCombinations := [][][]byte{
{},
}

// For each seed in the definition, expand the current list of combinations
// by all possible values for this seed.
for _, seed := range lookup.Seeds {
expansions := make([][]byte, 0)
if seed.Static != nil {
seedBytes = append(seedBytes, seed.Static)
}
if seed.Dynamic != nil {
expansions = append(expansions, seed.Static)
// Static and Dynamic seeds are mutually exclusive
} else if seed.Dynamic != nil {
dynamicSeed := seed.Dynamic
if lookupSeed, ok := dynamicSeed.(AccountLookup); ok {
// Get value from a location (This doens't have to be an address, it can be any value)
bytes, err := GetValuesAtLocation(args, lookupSeed.Location)
if err != nil {
return nil, fmt.Errorf("error getting address seed: %w", err)
return nil, fmt.Errorf("error getting address seed for location %q: %w", lookupSeed.Location, err)
}
// validate seed length
// append each byte array to the expansions
for _, b := range bytes {
// validate seed length
if len(b) > solana.MaxSeedLength {
return nil, fmt.Errorf("seed byte array exceeds maximum length of %d: got %d bytes", solana.MaxSeedLength, len(b))
}
seedBytes = append(seedBytes, b)
expansions = append(expansions, b)
}
} else {
// Get address seeds from the lookup
seedAddresses, err := GetAddresses(ctx, args, []Lookup{dynamicSeed}, derivedTableMap, reader)
if err != nil {
return nil, fmt.Errorf("error getting address seed: %w", err)
}

// Add each address seed as bytes
for _, address := range seedAddresses {
seedBytes = append(seedBytes, address.PublicKey.Bytes())
// Add each address seed to the expansions
for _, addrMeta := range seedAddresses {
b := addrMeta.PublicKey.Bytes()
if len(b) > solana.MaxSeedLength {
return nil, fmt.Errorf("seed byte array exceeds maximum length of %d: got %d bytes", solana.MaxSeedLength, len(b))
}
expansions = append(expansions, b)
}
}
}

// expansions is the list of possible seed bytes for this single seed lookup.
// Multiply the existing combinations in allCombinations by each item in expansions.
newCombinations := make([][][]byte, 0, len(allCombinations)*len(expansions))
for _, existingCombo := range allCombinations {
for _, expandedSeed := range expansions {
comboCopy := make([][]byte, len(existingCombo)+1)
copy(comboCopy, existingCombo)
comboCopy[len(existingCombo)] = expandedSeed
newCombinations = append(newCombinations, comboCopy)
}
}

allCombinations = newCombinations
}

return seedBytes, nil
return allCombinations, nil
}

// generatePDAs generates program-derived addresses (PDAs) from public keys and seeds.
func generatePDAs(publicKeys []*solana.AccountMeta, seeds [][]byte, lookup PDALookups) ([]*solana.AccountMeta, error) {
if len(seeds) > solana.MaxSeeds {
return nil, fmt.Errorf("seed maximum exceeded: %d", len(seeds))
}
var addresses []*solana.AccountMeta
// it will result in a list of PDAs whose length is the product of the number of public keys
// and the number of seed combinations.
func generatePDAs(
publicKeys []*solana.AccountMeta,
seedCombos [][][]byte,
lookup PDALookups,
) ([]*solana.AccountMeta, error) {
var results []*solana.AccountMeta
for _, publicKeyMeta := range publicKeys {
address, _, err := solana.FindProgramAddress(seeds, publicKeyMeta.PublicKey)
if err != nil {
return nil, fmt.Errorf("error finding program address: %w", err)
for _, combo := range seedCombos {
if len(combo) > solana.MaxSeeds {
return nil, fmt.Errorf("seed maximum exceeded: %d", len(combo))
}
address, _, err := solana.FindProgramAddress(combo, publicKeyMeta.PublicKey)
if err != nil {
return nil, fmt.Errorf("error finding program address: %w", err)
}
results = append(results, &solana.AccountMeta{
PublicKey: address,
IsSigner: lookup.IsSigner,
IsWritable: lookup.IsWritable,
})
}
addresses = append(addresses, &solana.AccountMeta{
PublicKey: address,
IsSigner: lookup.IsSigner,
IsWritable: lookup.IsWritable,
})
}
return addresses, nil

return results, nil
}
Loading
Loading