Skip to content

Commit

Permalink
strict_scale_decode config flag (#362)
Browse files Browse the repository at this point in the history
  • Loading branch information
arjanz authored Nov 15, 2023
1 parent 111ff53 commit e0c548c
Show file tree
Hide file tree
Showing 2 changed files with 34 additions and 14 deletions.
37 changes: 23 additions & 14 deletions substrateinterface/base.py
Original file line number Diff line number Diff line change
Expand Up @@ -69,7 +69,7 @@ class SubstrateInterface:

def __init__(self, url=None, websocket=None, ss58_format=None, type_registry=None, type_registry_preset=None,
cache_region=None, runtime_config=None, use_remote_preset=False, ws_options=None,
auto_discover=True, auto_reconnect=True):
auto_discover=True, auto_reconnect=True, config=None):
"""
A specialized class in interfacing with a Substrate node.
Expand All @@ -80,8 +80,9 @@ def __init__(self, url=None, websocket=None, ss58_format=None, type_registry=Non
type_registry: A dict containing the custom type registry in format: {'types': {'customType': 'u32'},..}
type_registry_preset: The name of the predefined type registry shipped with the SCALE-codec, e.g. kusama
cache_region: a Dogpile cache region as a central store for the metadata cache
use_remote_preset: When True preset is downloaded from Github master, otherwise use files from local installed scalecodec package
use_remote_preset: When True preset is downloaded from GitHub master, otherwise use files from local installed scalecodec package
ws_options: dict of options to pass to the websocket-client create_connection function
config: dict of config flags to overwrite default configuration
"""

if (not url and not websocket) or (url and websocket):
Expand Down Expand Up @@ -154,9 +155,14 @@ def __init__(self, url=None, websocket=None, ss58_format=None, type_registry=Non
'use_remote_preset': use_remote_preset,
'auto_discover': auto_discover,
'auto_reconnect': auto_reconnect,
'rpc_methods': None
'rpc_methods': None,
'strict_scale_decode': True
}

if type(config) is dict:
self.config.update(config)


# Initialize extension interface
self.extensions = ExtensionInterface(self)

Expand Down Expand Up @@ -1054,7 +1060,7 @@ def result_handler(storage_key, updated_obj, update_nr, subscription_id):
data=ScaleBytes(query_value),
metadata=self.metadata
)
obj.decode()
obj.decode(check_remaining=self.config.get('strict_scale_decode'))
obj.meta_info = {'result_found': response.get('result') is not None}

return obj
Expand All @@ -1080,11 +1086,14 @@ def __query_well_known(self, name: str, block_hash: str) -> ScaleType:
WELL_KNOWN_STORAGE_KEYS[name]['value_type_string']
)
if result:
obj.decode(ScaleBytes(result))
obj.decode(ScaleBytes(result), check_remaining=self.config.get('strict_scale_decode'))
obj.meta_info = {'result_found': True}
return obj
elif WELL_KNOWN_STORAGE_KEYS[name]['default']:
obj.decode(ScaleBytes(WELL_KNOWN_STORAGE_KEYS[name]['default']))
obj.decode(
ScaleBytes(WELL_KNOWN_STORAGE_KEYS[name]['default']),
check_remaining=self.config.get('strict_scale_decode')
)
obj.meta_info = {'result_found': False}
return obj
else:
Expand Down Expand Up @@ -1177,7 +1186,7 @@ def result_handler(message, update_nr, subscription_id):
data=ScaleBytes(change_data),
metadata=self.metadata
)
updated_obj.decode()
updated_obj.decode(check_remaining=self.config.get('strict_scale_decode'))
updated_obj.meta_info = {'result_found': result_found}

subscription_result = subscription_handler(storage_key, updated_obj, update_nr, subscription_id)
Expand Down Expand Up @@ -1212,7 +1221,7 @@ def retrieve_pending_extrinsics(self) -> list:

for extrinsic_data in result_data['result']:
extrinsic = self.runtime_config.create_scale_object('Extrinsic', metadata=self.metadata)
extrinsic.decode(ScaleBytes(extrinsic_data))
extrinsic.decode(ScaleBytes(extrinsic_data), check_remaining=self.config.get('strict_scale_decode'))
extrinsics.append(extrinsic)

return extrinsics
Expand Down Expand Up @@ -1271,7 +1280,7 @@ def runtime_call(self, api: str, method: str, params: Union[list, dict] = None,

# Decode result
result_obj = self.runtime_config.create_scale_object(runtime_call_def['type'])
result_obj.decode(ScaleBytes(result_data['result']))
result_obj.decode(ScaleBytes(result_data['result']), check_remaining=self.config.get('strict_scale_decode'))

return result_obj

Expand Down Expand Up @@ -2314,7 +2323,7 @@ def decode_block(block_data, block_data_hash=None):
runtime_config=self.runtime_config
)
try:
extrinsic_decoder.decode()
extrinsic_decoder.decode(check_remaining=self.config.get('strict_scale_decode'))
block_data['extrinsics'][idx] = extrinsic_decoder

except Exception as e:
Expand All @@ -2332,7 +2341,7 @@ def decode_block(block_data, block_data_hash=None):
raise NotImplementedError("No decoding class found for 'DigestItem'")

log_digest = log_digest_cls(data=ScaleBytes(log_data))
log_digest.decode()
log_digest.decode(check_remaining=self.config.get('strict_scale_decode'))

block_data['header']["digest"]["logs"][idx] = log_digest

Expand All @@ -2350,7 +2359,7 @@ def decode_block(block_data, block_data_hash=None):
data=ScaleBytes(bytes(log_digest[1][1]))
)

babe_predigest.decode()
babe_predigest.decode(check_remaining=self.config.get('strict_scale_decode'))

rank_validator = babe_predigest[1].value['authority_index']

Expand All @@ -2363,7 +2372,7 @@ def decode_block(block_data, block_data_hash=None):
data=ScaleBytes(bytes(log_digest[1][1]))
)

aura_predigest.decode()
aura_predigest.decode(check_remaining=self.config.get('strict_scale_decode'))

rank_validator = aura_predigest.value['slot_number'] % len(validator_set)

Expand Down Expand Up @@ -2666,7 +2675,7 @@ def decode_scale(self, type_string, scale_bytes, block_hash=None, return_scale_o
metadata=self.metadata
)

obj.decode()
obj.decode(check_remaining=self.config.get('strict_scale_decode'))

if return_scale_obj:
return obj
Expand Down
11 changes: 11 additions & 0 deletions test/test_init.py
Original file line number Diff line number Diff line change
Expand Up @@ -16,6 +16,8 @@

import unittest

from scalecodec import ScaleBytes
from scalecodec.exceptions import RemainingScaleBytesNotEmptyException
from substrateinterface import SubstrateInterface
from test import settings

Expand Down Expand Up @@ -99,6 +101,15 @@ def test_context_manager(self):

self.assertFalse(substrate.websocket.connected)

def test_strict_scale_decode(self):

with self.assertRaises(RemainingScaleBytesNotEmptyException):
self.kusama_substrate.decode_scale('u8', ScaleBytes('0x0101'))

with SubstrateInterface(url=settings.KUSAMA_NODE_URL, config={'strict_scale_decode': False}) as substrate:
result = substrate.decode_scale('u8', ScaleBytes('0x0101'))
self.assertEqual(result, 1)


if __name__ == '__main__':
unittest.main()

0 comments on commit e0c548c

Please sign in to comment.