Skip to content

Commit

Permalink
#86dtbrzdm - Fix class instantiation causing stack error
Browse files Browse the repository at this point in the history
  • Loading branch information
Mirella de Medeiros committed Jun 4, 2024
1 parent 6c90525 commit 776d9bc
Show file tree
Hide file tree
Showing 4 changed files with 75 additions and 24 deletions.
18 changes: 12 additions & 6 deletions boa3/internal/compiler/codegenerator/codegenerator.py
Original file line number Diff line number Diff line change
Expand Up @@ -522,19 +522,25 @@ def get_symbol(self, identifier: str, scope: ISymbol | None = None, is_internal:
return found_id, found_symbol
return identifier, Type.none

def initialize_static_fields(self) -> bool:
def initialize_static_fields(self) -> tuple[bool, bool]:
"""
Converts the signature of the method
:return: whether there are static fields to be initialized
:return: whether there are static fields to be initialized and if they can be generated already
"""
can_init_static_fields = False
has_static_fields = False
default_result = (has_static_fields, can_init_static_fields)

if not self.can_init_static_fields:
return False
return default_result
if self.initialized_static_fields:
return False
return default_result

num_static_fields = len(self._statics)
if num_static_fields > 0:
has_static_fields = num_static_fields > 0
can_init_static_fields = True
if has_static_fields:
init_data = bytearray([num_static_fields])
self.__insert1(OpcodeInfo.INITSSLOT, init_data)

Expand All @@ -548,7 +554,7 @@ def initialize_static_fields(self) -> bool:
init_method.init_bytecode = self.last_code
self.symbol_table[constants.INITIALIZE_METHOD_ID] = init_method

return num_static_fields > 0
return has_static_fields, can_init_static_fields

def end_initialize(self):
"""
Expand Down
38 changes: 20 additions & 18 deletions boa3/internal/compiler/codegenerator/codegeneratorvisitor.py
Original file line number Diff line number Diff line change
Expand Up @@ -205,7 +205,8 @@ def visit_Module(self, module: ast.Module) -> GeneratorData:
for stmt in function_stmts:
self.visit(stmt)

if self.generator.initialize_static_fields():
has_static_fields, can_initialize_static_fields = self.generator.initialize_static_fields()
if can_initialize_static_fields:
last_symbols = self.symbols # save to revert in the end and not compromise consequent visits
class_non_static_stmts = []

Expand Down Expand Up @@ -243,23 +244,24 @@ def visit_Module(self, module: ast.Module) -> GeneratorData:
class_non_static_stmts.append(cls_fun)
self.symbols = last_symbols # don't use inner scopes to evaluate the other globals

# to generate the 'initialize' method for Neo
self._log_info(f"Compiling '{constants.INITIALIZE_METHOD_ID}' function")
self._is_generating_initialize = True
for stmt in global_stmts:
cur_tree = self._tree
cur_filename = self.filename
if hasattr(stmt, 'origin'):
if hasattr(stmt.origin, 'filename'):
self.set_filename(stmt.origin.filename)
self._tree = stmt.origin

self.visit(stmt)
self.filename = cur_filename
self._tree = cur_tree

self._is_generating_initialize = False
self.generator.end_initialize()
if has_static_fields:
# to generate the 'initialize' method for Neo
self._log_info(f"Compiling '{constants.INITIALIZE_METHOD_ID}' function")
self._is_generating_initialize = True
for stmt in global_stmts:
cur_tree = self._tree
cur_filename = self.filename
if hasattr(stmt, 'origin'):
if hasattr(stmt.origin, 'filename'):
self.set_filename(stmt.origin.filename)
self._tree = stmt.origin

self.visit(stmt)
self.filename = cur_filename
self._tree = cur_tree

self._is_generating_initialize = False
self.generator.end_initialize()

# generate any symbol inside classes that's not variables AFTER generating 'initialize' method
for stmt in class_non_static_stmts:
Expand Down
27 changes: 27 additions & 0 deletions boa3_test/test_sc/class_test/ClassInitWithStaticVariable.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,27 @@
from typing import Any

from boa3.sc.compiletime import public

FOO = "bar"


class MyNFT:
def __init__(self, shape: str, color: str, background: str, size: str) -> None:
self.shape = shape
self.color = color
self.background = background
self.size = size

def export(self) -> dict:
return {
'shape': self.shape,
'color': self.color,
'background': self.background,
'size': self.size
}


@public
def test() -> Any:
nft = MyNFT('Rectangle', 'Blue', 'Black', 'Small')
return nft
16 changes: 16 additions & 0 deletions boa3_test/tests/compiler_tests/test_class.py
Original file line number Diff line number Diff line change
Expand Up @@ -489,3 +489,19 @@ async def test_return_dict_with_class_attributes(self):
}
result, _ = await self.call('test_pair', [], return_type=dict[str,str])
self.assertEqual(expected_result, result)

async def test_class_init_with_static_variable_no_optimization(self):
await self.set_up_contract('ClassInitWithStaticVariable.py', optimize=False)
from boa3_test.test_sc.class_test.ClassInitWithStaticVariable import MyNFT

expected_result = MyNFT('Rectangle', 'Blue', 'Black', 'Small')
result, _ = await self.call('test', [], return_type=list)
self.assertObjectEqual(expected_result, result)

async def test_class_init_with_static_variable_optimized(self):
await self.set_up_contract('ClassInitWithStaticVariable.py', optimize=True)
from boa3_test.test_sc.class_test.ClassInitWithStaticVariable import MyNFT

expected_result = MyNFT('Rectangle', 'Blue', 'Black', 'Small')
result, _ = await self.call('test', [], return_type=list)
self.assertObjectEqual(expected_result, result)

0 comments on commit 776d9bc

Please sign in to comment.