Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Python fix union deserialization #1335

Merged
Merged
Show file tree
Hide file tree
Changes from 5 commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
1 change: 1 addition & 0 deletions bindings/python/iota_sdk/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -30,6 +30,7 @@
from .types.output import *
from .types.output_data import *
from .types.output_id import *
from .types.output_metadata import *
from .types.output_params import *
from .types.payload import *
from .types.send_params import *
Expand Down
2 changes: 1 addition & 1 deletion bindings/python/iota_sdk/client/_high_level_api.py
Original file line number Diff line number Diff line change
Expand Up @@ -6,7 +6,7 @@
from abc import ABCMeta, abstractmethod
from iota_sdk.types.block import Block
from iota_sdk.types.common import CoinType, HexStr, json
from iota_sdk.types.output import OutputWithMetadata
from iota_sdk.types.output_metadata import OutputWithMetadata
from iota_sdk.types.output_id import OutputId


Expand Down
6 changes: 3 additions & 3 deletions bindings/python/iota_sdk/client/_node_core_api.py
Original file line number Diff line number Diff line change
@@ -1,14 +1,14 @@
# Copyright 2023 IOTA Stiftung
# SPDX-License-Identifier: Apache-2.0

from typing import List, Union
from typing import List, Optional, Union
from abc import ABCMeta, abstractmethod
from dacite import from_dict

from iota_sdk.types.block import Block, BlockMetadata
from iota_sdk.types.common import HexStr
from iota_sdk.types.node_info import NodeInfo, NodeInfoWrapper
from iota_sdk.types.output import OutputWithMetadata, OutputMetadata
from iota_sdk.types.output_metadata import OutputWithMetadata, OutputMetadata
from iota_sdk.types.output_id import OutputId


Expand Down Expand Up @@ -149,7 +149,7 @@ def get_included_block_metadata(
}))

def call_plugin_route(self, base_plugin_path: str, method: str,
endpoint: str, query_params: [str] = None, request: str = None):
endpoint: str, query_params: Optional[List[str]] = None, request: Optional[str] = None):
"""Extension method which provides request methods for plugins.

Args:
Expand Down
10 changes: 5 additions & 5 deletions bindings/python/iota_sdk/client/client.py
Original file line number Diff line number Diff line change
Expand Up @@ -18,7 +18,7 @@
from iota_sdk.types.feature import Feature
from iota_sdk.types.native_token import NativeToken
from iota_sdk.types.network_info import NetworkInfo
from iota_sdk.types.output import AccountOutput, BasicOutput, FoundryOutput, NftOutput, output_from_dict
from iota_sdk.types.output import AccountOutput, BasicOutput, FoundryOutput, NftOutput, deserialize_output
from iota_sdk.types.payload import Payload, TransactionPayload
from iota_sdk.types.token_scheme import SimpleTokenScheme
from iota_sdk.types.unlock_condition import UnlockCondition
Expand Down Expand Up @@ -197,7 +197,7 @@ def build_account_output(self,
if mana:
mana = str(mana)

return output_from_dict(self._call_method('buildAccountOutput', {
return deserialize_output(self._call_method('buildAccountOutput', {
'accountId': account_id,
'unlockConditions': unlock_conditions,
'amount': amount,
Expand Down Expand Up @@ -245,7 +245,7 @@ def build_basic_output(self,
if mana:
mana = str(mana)

return output_from_dict(self._call_method('buildBasicOutput', {
return deserialize_output(self._call_method('buildBasicOutput', {
'unlockConditions': unlock_conditions,
'amount': amount,
'mana': mana,
Expand Down Expand Up @@ -292,7 +292,7 @@ def build_foundry_output(self,
if amount:
amount = str(amount)

return output_from_dict(self._call_method('buildFoundryOutput', {
return deserialize_output(self._call_method('buildFoundryOutput', {
'serialNumber': serial_number,
'tokenScheme': token_scheme.to_dict(),
'unlockConditions': unlock_conditions,
Expand Down Expand Up @@ -344,7 +344,7 @@ def build_nft_output(self,
if mana:
mana = str(mana)

return output_from_dict(self._call_method('buildNftOutput', {
return deserialize_output(self._call_method('buildNftOutput', {
'nftId': nft_id,
'unlockConditions': unlock_conditions,
'amount': amount,
Expand Down
48 changes: 32 additions & 16 deletions bindings/python/iota_sdk/types/address.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,10 +2,8 @@
# SPDX-License-Identifier: Apache-2.0

from enum import IntEnum

from dataclasses import dataclass, field


from typing import Any, Dict, List, TypeAlias, Union
from iota_sdk.types.common import HexStr, json


Expand Down Expand Up @@ -69,19 +67,6 @@ class NFTAddress(Address):
type: int = field(default_factory=lambda: int(AddressType.NFT), init=False)


@json
@dataclass
# pylint: disable=function-redefined
# TODO: Change name
class AccountAddress():
thibault-martinez marked this conversation as resolved.
Show resolved Hide resolved
"""An Address of the Account.
"""
address: str
key_index: int
internal: bool
used: bool


@json
@dataclass
class AddressWithUnspentOutputs():
Expand All @@ -91,3 +76,34 @@ class AddressWithUnspentOutputs():
key_index: int
internal: bool
output_ids: bool


AddressUnion: TypeAlias = Union[Ed25519Address, AccountAddress, NFTAddress]
DaughterOfMars marked this conversation as resolved.
Show resolved Hide resolved


def deserialize_address(d: Dict[str, Any]) -> AddressUnion:
"""
Takes a dictionary as input and returns an instance of a specific class based on the value of the 'type' key in the dictionary.

Arguments:
* `d`: A dictionary that is expected to have a key called 'type' which specifies the type of the returned value.
"""
address_type = d['type']
if address_type == AddressType.ED25519:
return Ed25519Address.from_dict(d)
if address_type == AddressType.ACCOUNT:
return AccountAddress.from_dict(d)
if address_type == AddressType.NFT:
return NFTAddress.from_dict(d)
raise Exception(f'invalid address type: {address_type}')


def deserialize_addresses(
dicts: List[Dict[str, Any]]) -> List[AddressUnion]:
"""
Takes a list of dictionaries as input and returns a list with specific instances of a classes based on the value of the 'type' key in the dictionary.

Arguments:
* `dicts`: A list of dictionaries that are expected to have a key called 'type' which specifies the type of the returned value.
"""
return list(map(deserialize_address, dicts))
33 changes: 33 additions & 0 deletions bindings/python/iota_sdk/types/context_input.py
Original file line number Diff line number Diff line change
Expand Up @@ -4,6 +4,7 @@
from __future__ import annotations
from dataclasses import dataclass, field
from enum import IntEnum
from typing import Any, Dict, List, TypeAlias, Union
from iota_sdk.types.common import HexStr, json


Expand Down Expand Up @@ -70,3 +71,35 @@ class RewardContextInput(ContextInput):
default_factory=lambda: int(
ContextInputType.Reward),
init=False)


ContextInputUnion: TypeAlias = Union[CommitmentContextInput,
BlockIssuanceCreditContextInput, RewardContextInput]


def deserialize_context_input(d: Dict[str, Any]) -> ContextInputUnion:
"""
Takes a dictionary as input and returns an instance of a specific class based on the value of the 'type' key in the dictionary.

Arguments:
* `d`: A dictionary that is expected to have a key called 'type' which specifies the type of the returned value.
"""
context_input_type = dict['type']
if context_input_type == ContextInputType.Commitment:
return CommitmentContextInput.from_dict(d)
if context_input_type == ContextInputType.BlockIssuanceCredit:
return BlockIssuanceCreditContextInput.from_dict(d)
if context_input_type == ContextInputType.Reward:
return RewardContextInput.from_dict(d)
raise Exception(f'invalid context input type: {context_input_type}')


def deserialize_context_inputs(
dicts: List[Dict[str, Any]]) -> List[ContextInputUnion]:
"""
Takes a list of dictionaries as input and returns a list with specific instances of a classes based on the value of the 'type' key in the dictionary.

Arguments:
* `dicts`: A list of dictionaries that are expected to have a key called 'type' which specifies the type of the returned value.
"""
return list(map(deserialize_context_input, dicts))
14 changes: 5 additions & 9 deletions bindings/python/iota_sdk/types/essence.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,17 +3,15 @@

from __future__ import annotations
from enum import IntEnum
from typing import TYPE_CHECKING, Optional, List, Union
from typing import TYPE_CHECKING, Optional, List

from dataclasses import dataclass, field

from iota_sdk.types.common import HexStr, json, SlotIndex
from iota_sdk.types.mana import ManaAllotment
# TODO: Add missing output types in #1174
# pylint: disable=no-name-in-module
from iota_sdk.types.output import BasicOutput, AccountOutput, FoundryOutput, NftOutput, DelegationOutput
from iota_sdk.types.input import UtxoInput
from iota_sdk.types.context_input import CommitmentContextInput, BlockIssuanceCreditContextInput, RewardContextInput
from iota_sdk.types.context_input import ContextInputUnion
from iota_sdk.types.output import OutputUnion

# Required to prevent circular import
if TYPE_CHECKING:
Expand Down Expand Up @@ -57,10 +55,8 @@ class RegularTransactionEssence(TransactionEssence):
creation_slot: SlotIndex
inputs: List[UtxoInput]
inputs_commitment: HexStr
outputs: List[Union[BasicOutput, AccountOutput,
FoundryOutput, NftOutput, DelegationOutput]]
context_inputs: Optional[List[Union[CommitmentContextInput,
BlockIssuanceCreditContextInput, RewardContextInput]]] = None
outputs: List[OutputUnion]
context_inputs: Optional[List[ContextInputUnion]] = None
allotments: Optional[List[ManaAllotment]] = None
payload: Optional[Payload] = None
type: int = field(
Expand Down
55 changes: 49 additions & 6 deletions bindings/python/iota_sdk/types/feature.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,10 +2,10 @@
# SPDX-License-Identifier: Apache-2.0

from enum import IntEnum
from typing import List, Union
from typing import Dict, List, TypeAlias, Union, Any
from dataclasses import dataclass, field

from iota_sdk.types.address import Ed25519Address, AccountAddress, NFTAddress
from dataclasses_json import config
from iota_sdk.types.address import AddressUnion, deserialize_address
from iota_sdk.types.common import EpochIndex, HexStr, json, SlotIndex


Expand Down Expand Up @@ -43,7 +43,10 @@ class SenderFeature(Feature):
Attributes:
address: A given sender address.
"""
address: Union[Ed25519Address, AccountAddress, NFTAddress]
address: AddressUnion = field(
metadata=config(
decoder=deserialize_address
))
type: int = field(
default_factory=lambda: int(
FeatureType.Sender),
Expand All @@ -57,7 +60,10 @@ class IssuerFeature(Feature):
Attributes:
address: A given issuer address.
"""
address: Union[Ed25519Address, AccountAddress, NFTAddress]
address: AddressUnion = field(
metadata=config(
decoder=deserialize_address
))
type: int = field(
default_factory=lambda: int(
FeatureType.Issuer),
Expand Down Expand Up @@ -91,7 +97,7 @@ class TagFeature(Feature):

@json
@dataclass
class BlockIssuer(Feature):
class BlockIssuerFeature(Feature):
"""Contains the public keys to verify block signatures and allows for unbonding the issuer deposit.
Attributes:
expiry_slot: The slot index at which the Block Issuer Feature expires and can be removed.
Expand Down Expand Up @@ -124,3 +130,40 @@ class StakingFeature(Feature):
default_factory=lambda: int(
FeatureType.Staking),
init=False)


FeatureUnion: TypeAlias = Union[SenderFeature, IssuerFeature,
MetadataFeature, TagFeature, BlockIssuerFeature, StakingFeature]


def deserialize_feature(d: Dict[str, Any]) -> FeatureUnion:
"""
Takes a dictionary as input and returns an instance of a specific class based on the value of the 'type' key in the dictionary.

Arguments:
* `d`: A dictionary that is expected to have a key called 'type' which specifies the type of the returned value.
"""
feature_type = d['type']
if feature_type == FeatureType.Sender:
return SenderFeature.from_dict(d)
if feature_type == FeatureType.Issuer:
return IssuerFeature.from_dict(d)
if feature_type == FeatureType.Metadata:
return MetadataFeature.from_dict(d)
if feature_type == FeatureType.Tag:
return TagFeature.from_dict(d)
if feature_type == FeatureType.BlockIssuer:
return BlockIssuerFeature.from_dict(d)
if feature_type == FeatureType.Staking:
return StakingFeature.from_dict(d)
raise Exception(f'invalid feature type: {feature_type}')


def deserialize_features(dicts: List[Dict[str, Any]]) -> List[FeatureUnion]:
"""
Takes a list of dictionaries as input and returns a list with specific instances of a classes based on the value of the 'type' key in the dictionary.

Arguments:
* `dicts`: A list of dictionaries that are expected to have a key called 'type' which specifies the type of the returned value.
"""
return list(map(deserialize_feature, dicts))
Loading
Loading