Skip to content

Commit

Permalink
can build SetComputeUnitLimit instruction
Browse files Browse the repository at this point in the history
  • Loading branch information
aalu1418 committed Sep 24, 2024
1 parent 46a444d commit 48a3bb8
Show file tree
Hide file tree
Showing 3 changed files with 165 additions and 53 deletions.
106 changes: 82 additions & 24 deletions pkg/solana/fees/computebudget.go
Original file line number Diff line number Diff line change
Expand Up @@ -6,12 +6,13 @@ import (
"fmt"

"github.com/gagliardetto/solana-go"
"golang.org/x/exp/constraints"
)

// https://github.com/solana-labs/solana/blob/60858d043ca612334de300805d93ea3014e8ab37/sdk/src/compute_budget.rs#L25
const (
// deprecated: will not support for building instruction
InstructionRequestUnitsDeprecated uint8 = iota
InstructionRequestUnitsDeprecated computeBudgetInstruction = iota

// Request a specific transaction-wide program heap region size in bytes.
// The value requested must be a multiple of 1024. This new heap region
Expand All @@ -30,29 +31,61 @@ const (
InstructionSetComputeUnitPrice
)

const (
ComputeBudgetProgram = "ComputeBudget111111111111111111111111111111"
var (
ComputeBudgetProgram = solana.MustPublicKeyFromBase58("ComputeBudget111111111111111111111111111111")
)

type computeBudgetInstruction uint8

func (ins computeBudgetInstruction) String() (out string) {
out = "INVALID"
switch ins {
case InstructionRequestUnitsDeprecated:
out = "RequestUnitsDeprecated"
case InstructionRequestHeapFrame:
out = "RequestHeapFrame"
case InstructionSetComputeUnitLimit:
out = "SetComputeUnitLimit"
case InstructionSetComputeUnitPrice:
out = "SetComputeUnitPrice"
}
return out
}

// instruction is an internal interface for encoding instruction data
type instruction interface {
Data() ([]byte, error)
Selector() computeBudgetInstruction
}

// https://docs.solana.com/developing/programming-model/runtime
type ComputeUnitPrice uint64

// returns the compute budget program
func (val ComputeUnitPrice) ProgramID() solana.PublicKey {
return solana.MustPublicKeyFromBase58(ComputeBudgetProgram)
// simple encoding into program expected format
func (val ComputeUnitPrice) Data() ([]byte, error) {
return encode(InstructionSetComputeUnitPrice, val)
}

// No accounts needed
func (val ComputeUnitPrice) Accounts() (accounts []*solana.AccountMeta) {
return accounts
func (val ComputeUnitPrice) Selector() computeBudgetInstruction {
return InstructionSetComputeUnitPrice
}

// simple encoding into program expected format
func (val ComputeUnitPrice) Data() ([]byte, error) {
type ComputeUnitLimit uint32

func (val ComputeUnitLimit) Data() ([]byte, error) {
return encode(InstructionSetComputeUnitLimit, val)
}

func (val ComputeUnitLimit) Selector() computeBudgetInstruction {
return InstructionSetComputeUnitLimit
}

// encode combines the identifier and little encoded value into a byte array
func encode[V constraints.Unsigned](identifier computeBudgetInstruction, val V) ([]byte, error) {
buf := new(bytes.Buffer)

// encode method identifier
if err := buf.WriteByte(InstructionSetComputeUnitPrice); err != nil {
if err := buf.WriteByte(uint8(identifier)); err != nil {
return []byte{}, err
}

Expand All @@ -65,42 +98,64 @@ func (val ComputeUnitPrice) Data() ([]byte, error) {
}

func ParseComputeUnitPrice(data []byte) (ComputeUnitPrice, error) {
if len(data) != (1 + 8) { // instruction byte + uint64
v, err := parse(InstructionSetComputeUnitPrice, data, binary.LittleEndian.Uint64)
return ComputeUnitPrice(v), err

}

func ParseComputeUnitLimit(data []byte) (ComputeUnitLimit, error) {
v, err := parse(InstructionSetComputeUnitLimit, data, binary.LittleEndian.Uint32)
return ComputeUnitLimit(v), err
}

// parse implements tx data parsing for the provided instruction type and specified decoder
func parse[V constraints.Unsigned](ins computeBudgetInstruction, data []byte, decoder func([]byte) V) (V, error) {
if len(data) != (1 + binary.Size(V(0))) { // instruction byte + uintXXX length
return 0, fmt.Errorf("invalid length: %d", len(data))
}

if data[0] != InstructionSetComputeUnitPrice {
return 0, fmt.Errorf("not SetComputeUnitPrice identifier: %d", data[0])
// validate instruction identifier
if data[0] != uint8(ins) {
return 0, fmt.Errorf("not %s identifier: %d", ins, data[0])
}

// guarantees length 8
return ComputeUnitPrice(binary.LittleEndian.Uint64(data[1:])), nil
// guarantees length to fit the binary decoder
return decoder(data[1:]), nil
}

// modifies passed in tx to set compute unit price
func SetComputeUnitPrice(tx *solana.Transaction, price ComputeUnitPrice) error {
func SetComputeUnitPrice(tx *solana.Transaction, value ComputeUnitPrice) error {
return set(tx, value, true) // data feeds expects SetComputeUnitPrice instruction to be right before report instruction
}

func SetComputeUnitLimit(tx *solana.Transaction, value ComputeUnitLimit) error {
return set(tx, value, false) // appends instruction to the end
}

// set adds or modifies instructions for the compute budget program
func set(tx *solana.Transaction, baseData instruction, appendToFront bool) error {
// find ComputeBudget program to accounts if it exists
// reimplements HasAccount to retrieve index: https://github.com/gagliardetto/solana-go/blob/618f56666078f8131a384ab27afd918d248c08b7/message.go#L233
var exists bool
var programIdx uint16
for i, a := range tx.Message.AccountKeys {
if a.Equals(price.ProgramID()) {
if a.Equals(ComputeBudgetProgram) {
exists = true
programIdx = uint16(i)
break
}
}
// if it doesn't exist, add to account keys
if !exists {
tx.Message.AccountKeys = append(tx.Message.AccountKeys, price.ProgramID())
tx.Message.AccountKeys = append(tx.Message.AccountKeys, ComputeBudgetProgram)
programIdx = uint16(len(tx.Message.AccountKeys) - 1) // last index of account keys

// https://github.com/gagliardetto/solana-go/blob/618f56666078f8131a384ab27afd918d248c08b7/transaction.go#L293
tx.Message.Header.NumReadonlyUnsignedAccounts++
}

// get instruction data
data, err := price.Data()
data, err := baseData.Data()
if err != nil {
return err
}
Expand All @@ -117,7 +172,7 @@ func SetComputeUnitPrice(tx *solana.Transaction, price ComputeUnitPrice) error {
for i := range tx.Message.Instructions {
if tx.Message.Instructions[i].ProgramIDIndex == programIdx &&
len(tx.Message.Instructions[i].Data) > 0 &&
tx.Message.Instructions[i].Data[0] == InstructionSetComputeUnitPrice {
tx.Message.Instructions[i].Data[0] == uint8(baseData.Selector()) {
found = true
instructionIdx = i
break
Expand All @@ -127,8 +182,11 @@ func SetComputeUnitPrice(tx *solana.Transaction, price ComputeUnitPrice) error {
if found {
tx.Message.Instructions[instructionIdx] = instruction
} else {
// build with first instruction as set compute unit price
tx.Message.Instructions = append([]solana.CompiledInstruction{instruction}, tx.Message.Instructions...)
if appendToFront {
tx.Message.Instructions = append([]solana.CompiledInstruction{instruction}, tx.Message.Instructions...)
} else {
tx.Message.Instructions = append(tx.Message.Instructions, instruction)
}
}

return nil
Expand Down
110 changes: 82 additions & 28 deletions pkg/solana/fees/computebudget_test.go
Original file line number Diff line number Diff line change
@@ -1,6 +1,7 @@
package fees

import (
"fmt"
"testing"

"github.com/gagliardetto/solana-go"
Expand All @@ -9,11 +10,36 @@ import (
"github.com/stretchr/testify/require"
)

func TestSetComputeUnitPrice(t *testing.T) {
func TestSet(t *testing.T) {
t.Run("ComputeUnitPrice", func(t *testing.T) {
t.Parallel()
testSet(t, func(v uint) ComputeUnitPrice {
return ComputeUnitPrice(v)
}, SetComputeUnitPrice, true)
})
t.Run("ComputeUnitLimit", func(t *testing.T) {
t.Parallel()
testSet(t, func(v uint) ComputeUnitLimit {
return ComputeUnitLimit(v)
}, SetComputeUnitLimit, false)
})

}

func testSet[V instruction](t *testing.T, builder func(uint) V, setter func(*solana.Transaction, V) error, expectFirstInstruction bool) {
key, err := solana.NewRandomPrivateKey()
require.NoError(t, err)

getIndex := func(count int) int {
index := count - 1
if expectFirstInstruction {
index = 0
}
return index
}

t.Run("noAccount_nofee", func(t *testing.T) {
t.Parallel()
// build base tx (no fee)
tx, err := solana.NewTransaction([]solana.Instruction{
system.NewTransferInstruction(
Expand All @@ -26,19 +52,21 @@ func TestSetComputeUnitPrice(t *testing.T) {
instructionCount := len(tx.Message.Instructions)

// add fee
require.NoError(t, SetComputeUnitPrice(tx, 1))
require.NoError(t, setter(tx, builder(1)))

// evaluate
currentCount := len(tx.Message.Instructions)
assert.Greater(t, currentCount, instructionCount)
assert.Equal(t, 2, currentCount)
assert.Equal(t, ComputeBudgetProgram, tx.Message.AccountKeys[tx.Message.Instructions[0].ProgramIDIndex].String())
data, err := ComputeUnitPrice(1).Data()
i := getIndex(currentCount)
assert.Equal(t, ComputeBudgetProgram, tx.Message.AccountKeys[tx.Message.Instructions[i].ProgramIDIndex])
data, err := builder(1).Data()
assert.NoError(t, err)
assert.Equal(t, data, []byte(tx.Message.Instructions[0].Data))
assert.Equal(t, data, []byte(tx.Message.Instructions[i].Data))
})

t.Run("accountExists_noFee", func(t *testing.T) {
t.Parallel()
// build base tx (no fee)
tx, err := solana.NewTransaction([]solana.Instruction{
system.NewTransferInstruction(
Expand All @@ -49,25 +77,27 @@ func TestSetComputeUnitPrice(t *testing.T) {
}, solana.Hash{})
require.NoError(t, err)
accountCount := len(tx.Message.AccountKeys)
tx.Message.AccountKeys = append(tx.Message.AccountKeys, ComputeUnitPrice(0).ProgramID())
tx.Message.AccountKeys = append(tx.Message.AccountKeys, ComputeBudgetProgram)
accountCount++

// add fee
require.NoError(t, SetComputeUnitPrice(tx, 1))
require.NoError(t, setter(tx, builder(1)))

// accounts should not have changed
assert.Equal(t, accountCount, len(tx.Message.AccountKeys))
assert.Equal(t, 2, len(tx.Message.Instructions))
assert.Equal(t, ComputeBudgetProgram, tx.Message.AccountKeys[tx.Message.Instructions[0].ProgramIDIndex].String())
data, err := ComputeUnitPrice(1).Data()
i := getIndex(len(tx.Message.Instructions))
assert.Equal(t, ComputeBudgetProgram, tx.Message.AccountKeys[tx.Message.Instructions[i].ProgramIDIndex])
data, err := builder(1).Data()
assert.NoError(t, err)
assert.Equal(t, data, []byte(tx.Message.Instructions[0].Data))
assert.Equal(t, data, []byte(tx.Message.Instructions[i].Data))
})

// // not a valid test, account must exist for tx to be added
// t.Run("noAccount_feeExists", func(t *testing.T) {})

t.Run("exists_notFirst", func(t *testing.T) {
t.Run("exists_unknownOrder", func(t *testing.T) {
t.Parallel()
// build base tx (no fee)
tx, err := solana.NewTransaction([]solana.Instruction{
system.NewTransferInstruction(
Expand All @@ -80,42 +110,66 @@ func TestSetComputeUnitPrice(t *testing.T) {
transferInstruction := tx.Message.Instructions[0]

// add fee
require.NoError(t, SetComputeUnitPrice(tx, 0))
require.NoError(t, setter(tx, builder(0)))

// swap order of instructions
tx.Message.Instructions[0], tx.Message.Instructions[1] = tx.Message.Instructions[1], tx.Message.Instructions[0]
require.Equal(t, transferInstruction, tx.Message.Instructions[0])
oldFeeInstruction := tx.Message.Instructions[1]

// after swap
computeIndex := 0
transferIndex := 1
if expectFirstInstruction {
computeIndex = 1
transferIndex = 0
}

require.Equal(t, transferInstruction, tx.Message.Instructions[transferIndex])
oldComputeInstruction := tx.Message.Instructions[computeIndex]
accountCount := len(tx.Message.AccountKeys)

// set fee with existing fee instruction
require.NoError(t, SetComputeUnitPrice(tx, 100))
require.Equal(t, transferInstruction, tx.Message.Instructions[0]) // transfer should not have been touched
assert.NotEqual(t, oldFeeInstruction, tx.Message.Instructions[1])
require.NoError(t, setter(tx, builder(100)))
require.Equal(t, transferInstruction, tx.Message.Instructions[transferIndex]) // transfer should not have been touched
assert.NotEqual(t, oldComputeInstruction, tx.Message.Instructions[computeIndex])
assert.Equal(t, accountCount, len(tx.Message.AccountKeys))
assert.Equal(t, 2, len(tx.Message.Instructions)) // instruction count did not change
data, err := ComputeUnitPrice(100).Data()
data, err := builder(100).Data()
assert.NoError(t, err)
assert.Equal(t, data, []byte(tx.Message.Instructions[1].Data))
assert.Equal(t, data, []byte(tx.Message.Instructions[computeIndex].Data))
})
}

func TestParse(t *testing.T) {
t.Run("ComputeUnitPrice", func(t *testing.T) {
t.Parallel()
testParse(t, func(v uint) ComputeUnitPrice {
return ComputeUnitPrice(v)
}, ParseComputeUnitPrice)
})
t.Run("ComputeUnitLimit", func(t *testing.T) {
t.Parallel()
testParse(t, func(v uint) ComputeUnitLimit {
return ComputeUnitLimit(v)
}, ParseComputeUnitLimit)
})
}

func TestParseComputeUnitPrice(t *testing.T) {
data, err := ComputeUnitPrice(100).Data()
func testParse[V instruction](t *testing.T, builder func(uint) V, parser func([]byte) (V, error)) {
data, err := builder(100).Data()
assert.NoError(t, err)

v, err := ParseComputeUnitPrice(data)
v, err := parser(data)
assert.NoError(t, err)
assert.Equal(t, ComputeUnitPrice(100), v)
assert.Equal(t, builder(100), v)

_, err = ParseComputeUnitPrice([]byte{})
_, err = parser([]byte{})
assert.ErrorContains(t, err, "invalid length")
tooLong := [10]byte{}
_, err = ParseComputeUnitPrice(tooLong[:])
_, err = parser(tooLong[:])
assert.ErrorContains(t, err, "invalid length")

invalidData := data
invalidData[0] = InstructionRequestHeapFrame
_, err = ParseComputeUnitPrice(invalidData)
assert.ErrorContains(t, err, "not SetComputeUnitPrice identifier")
invalidData[0] = uint8(InstructionRequestHeapFrame)
_, err = parser(invalidData)
assert.ErrorContains(t, err, fmt.Sprintf("not %s identifier", builder(0).Selector()))
}
2 changes: 1 addition & 1 deletion pkg/solana/fees/utils.go
Original file line number Diff line number Diff line change
Expand Up @@ -71,7 +71,7 @@ func ParseBlock(res *rpc.GetBlockResult) (out BlockData, err error) {
var price ComputeUnitPrice // default 0
for _, instruction := range baseTx.Message.Instructions {
// find instructions for compute budget program
if baseTx.Message.AccountKeys[instruction.ProgramIDIndex] == solana.MustPublicKeyFromBase58(ComputeBudgetProgram) {
if baseTx.Message.AccountKeys[instruction.ProgramIDIndex] == ComputeBudgetProgram {
parsed, parseErr := ParseComputeUnitPrice(instruction.Data)
// if compute unit price found, break instruction loop
// only one compute unit price tx is allowed
Expand Down

0 comments on commit 48a3bb8

Please sign in to comment.