Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Manged decimal #172

Open
wants to merge 3 commits into
base: feat/next
Choose a base branch
from
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
26 changes: 26 additions & 0 deletions multiversx_sdk/abi/abi.py
Original file line number Diff line number Diff line change
Expand Up @@ -21,6 +21,7 @@
from multiversx_sdk.abi.fields import Field
from multiversx_sdk.abi.interface import IPayloadHolder
from multiversx_sdk.abi.list_value import ListValue
from multiversx_sdk.abi.managed_decimal_value import ManagedDecimalValue
from multiversx_sdk.abi.multi_value import MultiValue
from multiversx_sdk.abi.option_value import OptionValue
from multiversx_sdk.abi.optional_value import OptionalValue
Expand All @@ -36,6 +37,8 @@
from multiversx_sdk.abi.type_formula_parser import TypeFormulaParser
from multiversx_sdk.abi.variadic_values import VariadicValues

from multiversx_sdk.abi.managed_decimal_signed_value import ManagedDecimalSignedValue


class Abi:
def __init__(self, definition: AbiDefinition) -> None:
Expand Down Expand Up @@ -316,6 +319,29 @@ def _create_prototype(self, type_formula: TypeFormula) -> Any:
return CountedVariadicValues([], item_creator=lambda: self._create_prototype(type_parameter))
if name == "multi":
return MultiValue([self._create_prototype(type_parameter) for type_parameter in type_formula.type_parameters])
if name == "ManagedDecimal":
scale = type_formula.type_parameters[0].name

if scale == "usize":
scale = 0
is_variable = True
else:
scale = int(scale)
is_variable = False

return ManagedDecimalValue(scale=scale, is_variable=is_variable)
if name == "ManagedDecimalSigned":
scale = type_formula.type_parameters[0].name

if scale == "usize":
scale = 0
is_variable = True
else:
scale = int(scale)
is_variable = False

return ManagedDecimalSignedValue(scale=scale, is_variable=is_variable)


# Handle custom types
type_prototype = self._get_custom_type_prototype(name)
Expand Down
38 changes: 37 additions & 1 deletion multiversx_sdk/abi/abi_test.py
Original file line number Diff line number Diff line change
@@ -1,9 +1,10 @@
from decimal import Decimal
from pathlib import Path
from types import SimpleNamespace
from typing import Optional

from multiversx_sdk.abi.abi import Abi
from multiversx_sdk.abi.abi_definition import ParameterDefinition
from multiversx_sdk.abi.abi_definition import AbiDefinition, ParameterDefinition
from multiversx_sdk.abi.address_value import AddressValue
from multiversx_sdk.abi.biguint_value import BigUIntValue
from multiversx_sdk.abi.bytes_value import BytesValue
Expand All @@ -19,6 +20,8 @@
from multiversx_sdk.abi.variadic_values import VariadicValues
from multiversx_sdk.core.address import Address

from multiversx_sdk.abi.managed_decimal_value import ManagedDecimalValue

testdata = Path(__file__).parent.parent / "testutils" / "testdata"


Expand Down Expand Up @@ -318,3 +321,36 @@ def test_decode_endpoint_output_parameters_multisig_get_pending_action_full_info
Address.from_bech32("erd1qyu5wthldzr8wx5c9ucg8kjagg0jfs53s8nr3zpz3hypefsdd8ssycr6th").get_public_key(),
Address.from_bech32("erd1spyavw0956vq68xj8y4tenjpq2wd5a9p2c6j8gsz7ztyrnpxrruqzu66jx").get_public_key(),
]


def test_managed_decimals():
abi_definition = AbiDefinition.from_dict({
"endpoints": [{
"name": "foo",
"inputs": [
{
"type": "ManagedDecimal<8>"
},
{
"type": "ManagedDecimal<usize>"
}
],
"outputs": []
}]
})

abi = Abi(abi_definition)
endpoint = abi.endpoints_prototypes_by_name["foo"]

first_input = endpoint.input_parameters[0]
second_input = endpoint.input_parameters[1]

assert isinstance(first_input, ManagedDecimalValue)
assert not first_input.is_variable
assert first_input.scale == 8
assert first_input.value == Decimal(0)

assert isinstance(second_input, ManagedDecimalValue)
assert second_input.is_variable
assert second_input.scale == 0
assert second_input.value == Decimal(0)
1 change: 1 addition & 0 deletions multiversx_sdk/abi/constants.py
Original file line number Diff line number Diff line change
Expand Up @@ -6,3 +6,4 @@
INTEGER_MAX_NUM_BYTES = 64
STRUCT_PACKING_FORMAT_FOR_UINT32 = ">I"
ENUM_DISCRIMINANT_FIELD_NAME = "__discriminant__"
U32_SIZE_IN_BYTES = 4
173 changes: 173 additions & 0 deletions multiversx_sdk/abi/localnet_integration_test.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,173 @@
from decimal import Decimal
from pathlib import Path

import pytest

from multiversx_sdk.abi.abi import Abi, AbiDefinition
from multiversx_sdk.abi.managed_decimal_value import ManagedDecimalValue
from multiversx_sdk.accounts.account import Account
from multiversx_sdk.network_providers.proxy_network_provider import \
ProxyNetworkProvider
from multiversx_sdk.smart_contracts.smart_contract_controller import \
SmartContractController
from multiversx_sdk.testutils.wallets import load_wallets


@pytest.mark.skip("Requires localnet")
class TestLocalnetInteraction:
wallets = load_wallets()
alice = wallets["alice"]
testdata = Path(__file__).parent.parent / "testutils" / "testdata"

def test_managed_decimal(self):
abi_definition = AbiDefinition.from_dict(
{
"endpoints": [
{
"name": "returns_egld_decimal",
"mutability": "mutable",
"payableInTokens": ["EGLD"],
"inputs": [],
"outputs": [{"type": "ManagedDecimal<18>"}],
},
{
"name": "managed_decimal_addition",
"mutability": "mutable",
"inputs": [
{"name": "first", "type": "ManagedDecimal<2>"},
{"name": "second", "type": "ManagedDecimal<2>"},
],
"outputs": [{"type": "ManagedDecimal<2>"}],
},
{
"name": "managed_decimal_ln",
"mutability": "mutable",
"inputs": [{"name": "x", "type": "ManagedDecimal<9>"}],
"outputs": [{"type": "ManagedDecimalSigned<9>"}],
},
{
"name": "managed_decimal_addition_var",
"mutability": "mutable",
"inputs": [
{"name": "first", "type": "ManagedDecimal<usize>"},
{"name": "second", "type": "ManagedDecimal<usize>"},
],
"outputs": [{"type": "ManagedDecimal<usize>"}],
},
{
"name": "managed_decimal_ln_var",
"mutability": "mutable",
"inputs": [{"name": "x", "type": "ManagedDecimal<usize>"}],
"outputs": [{"type": "ManagedDecimalSigned<9>"}],
},
]
}
)

abi = Abi(abi_definition)

proxy = ProxyNetworkProvider("http://localhost:7950")
sc_controller = SmartContractController(
chain_id="localnet",
network_provider=proxy,
abi=abi,
)

alice = Account(self.alice.secret_key)
alice.nonce = proxy.get_account(alice.address).nonce

# deploy contract
deploy_tx = sc_controller.create_transaction_for_deploy(
sender=alice,
nonce=alice.get_nonce_then_increment(),
bytecode=self.testdata / "basic-features.wasm",
gas_limit=600_000_000,
arguments=[],
)

deploy_tx_hash = proxy.send_transaction(deploy_tx)
deploy_outcome = sc_controller.await_completed_deploy(deploy_tx_hash)
assert deploy_outcome.return_code == "ok"

contract = deploy_outcome.contracts[0].address

# return egld decimals
return_egld_transaction = sc_controller.create_transaction_for_execute(
sender=alice,
nonce=alice.get_nonce_then_increment(),
contract=contract,
gas_limit=100_000_000,
function="returns_egld_decimal",
arguments=[],
native_transfer_amount=1,
)

tx_hash = proxy.send_transaction(return_egld_transaction)
outcome = sc_controller.await_completed_execute(tx_hash)
assert outcome.return_code == "ok"
assert len(outcome.values) == 1
assert outcome.values[0] == Decimal("0.000000000000000001")

# addition with const decimals
addition_transaction = sc_controller.create_transaction_for_execute(
sender=alice,
nonce=alice.get_nonce_then_increment(),
contract=contract,
gas_limit=100_000_000,
function="managed_decimal_addition",
arguments=[ManagedDecimalValue("2.5", 2), ManagedDecimalValue("2.7", 2)],
)

tx_hash = proxy.send_transaction(addition_transaction)
outcome = sc_controller.await_completed_execute(tx_hash)
assert outcome.return_code == "ok"
assert len(outcome.values) == 1
assert outcome.values[0] == Decimal("5.2")

# log
md_ln_transaction = sc_controller.create_transaction_for_execute(
sender=alice,
nonce=alice.get_nonce_then_increment(),
contract=contract,
gas_limit=100_000_000,
function="managed_decimal_ln",
arguments=[ManagedDecimalValue("23", 9)],
)

tx_hash = proxy.send_transaction(md_ln_transaction)
outcome = sc_controller.await_completed_execute(tx_hash)
assert outcome.return_code == "ok"
assert len(outcome.values) == 1
assert outcome.values[0] == Decimal("3.135553845")

# addition var decimals
addition_var_transaction = sc_controller.create_transaction_for_execute(
sender=alice,
nonce=alice.get_nonce_then_increment(),
contract=contract,
gas_limit=50_000_000,
function="managed_decimal_addition_var",
arguments=[ManagedDecimalValue("4", 2, True), ManagedDecimalValue("5", 2, True)],
)

tx_hash = proxy.send_transaction(addition_var_transaction)
outcome = sc_controller.await_completed_execute(tx_hash)
assert outcome.return_code == "ok"
assert len(outcome.values) == 1
assert outcome.values[0] == Decimal("9")

# ln var
ln_var_transaction = sc_controller.create_transaction_for_execute(
sender=alice,
nonce=alice.get_nonce_then_increment(),
contract=contract,
gas_limit=50_000_000,
function="managed_decimal_ln_var",
arguments=[ManagedDecimalValue("23", 9, True)],
)

tx_hash = proxy.send_transaction(ln_var_transaction)
outcome = sc_controller.await_completed_execute(tx_hash)
assert outcome.return_code == "ok"
assert len(outcome.values) == 1
assert outcome.values[0] == Decimal("3.135553845")
103 changes: 103 additions & 0 deletions multiversx_sdk/abi/managed_decimal_signed_value.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,103 @@
import io
from decimal import ROUND_DOWN, Decimal
from typing import Any, Union

from multiversx_sdk.abi.bigint_value import BigIntValue
from multiversx_sdk.abi.constants import U32_SIZE_IN_BYTES
from multiversx_sdk.abi.shared import read_bytes_exactly
from multiversx_sdk.abi.small_int_values import U32Value


class ManagedDecimalSignedValue:
def __init__(self, value: Union[int, str] = 0, scale: int = 0, is_variable: bool = False):
self.value = Decimal(value)
self.scale = scale
self.is_variable = is_variable

def set_payload(self, value: Any):
if isinstance(value, ManagedDecimalSignedValue):
if self.is_variable != value.is_variable:
raise Exception("Cannot set payload! Both ManagedDecimalValues should be variable.")

self.value = value.value

if self.is_variable:
self.scale = value.scale
else:
self.value = self._convert_to_decimal(value)

def get_payload(self) -> Decimal:
return self.value

def encode_top_level(self, writer: io.BytesIO):
self.encode_nested(writer)

def encode_nested(self, writer: io.BytesIO):
raw_value = BigIntValue(self._convert_value_to_int())
if self.is_variable:
raw_value.encode_nested(writer)
U32Value(self.scale).encode_nested(writer)
else:
raw_value.encode_top_level(writer)

def decode_top_level(self, data: bytes):
if not data:
self.value = Decimal(0)
self.scale = 0
return

bigint = BigIntValue()
scale = U32Value()

if self.is_variable:
# read biguint value length in bytes
big_int_size = self._unsigned_from_bytes(data[:U32_SIZE_IN_BYTES])

# remove biguint length; data is only biguint value and scale
data = data[U32_SIZE_IN_BYTES:]

# read biguint value
bigint.decode_top_level(data[:big_int_size])

# remove biguintvalue; data contains only scale
data = data[big_int_size:]

# read scale
scale.decode_top_level(data)
self.scale = scale.get_payload()
else:
bigint.decode_top_level(data)

self.value = self._convert_to_decimal(bigint.get_payload())

def decode_nested(self, reader: io.BytesIO):
length = self._unsigned_from_bytes(read_bytes_exactly(reader, U32_SIZE_IN_BYTES))
payload = read_bytes_exactly(reader, length)
self.decode_top_level(payload)

def to_string(self) -> str:
value_str = str(self._convert_value_to_int())
if self.scale == 0:
return value_str
if len(value_str) <= self.scale:
# If the value is smaller than the scale, prepend zeros
value_str = "0" * (self.scale - len(value_str) + 1) + value_str
return f"{value_str[:-self.scale]}.{value_str[-self.scale:]}"

def get_precision(self) -> int:
return len(str(self._convert_value_to_int()).lstrip("0"))

def _unsigned_from_bytes(self, data: bytes) -> int:
return int.from_bytes(data, byteorder="big", signed=False)

def _convert_value_to_int(self) -> int:
scaled_value: Decimal = self.value * (10**self.scale)
return int(scaled_value.quantize(Decimal("1."), rounding=ROUND_DOWN))

def _convert_to_decimal(self, value: Union[int, str]) -> Decimal:
return Decimal(value) / (10**self.scale)

def __eq__(self, other: object) -> bool:
if not isinstance(other, ManagedDecimalSignedValue):
return False
return self.value == other.value and self.scale == other.scale
Loading
Loading