Skip to content

Commit

Permalink
No errors, only test failures
Browse files Browse the repository at this point in the history
  • Loading branch information
dipankar committed Sep 20, 2024
1 parent 138caa2 commit 489cd43
Show file tree
Hide file tree
Showing 3 changed files with 57 additions and 20 deletions.
2 changes: 1 addition & 1 deletion stxscript/ast_nodes.py
Original file line number Diff line number Diff line change
Expand Up @@ -212,7 +212,7 @@ class ListComprehension(Expression):
expression: Expression
iterable: Expression
iterator: Identifier
condition: Optional[Expression]
condition: Optional[Expression] = None

@dataclass
class ContractCallExpression(Expression):
Expand Down
60 changes: 47 additions & 13 deletions stxscript/clarity_generator.py
Original file line number Diff line number Diff line change
Expand Up @@ -8,21 +8,47 @@ def indent(self):
return " " * self.indent_level

def generate(self, node):
# Debug print to check the type of node being processed
if node is None:
return ""
elif isinstance(node, str):
return self.generate_str(node)
elif isinstance(node, int):
return str(node)
elif isinstance(node, float):
return str(node)
elif isinstance(node, bool):
return 'true' if node else 'false'
elif isinstance(node, list):
return self.generate_list(node)
elif isinstance(node, tuple):
return self.generate_tuple(node)
elif isinstance(node, dict):
return self.generate_dict(node)
elif isinstance(node, Identifier):
return self.generate_Identifier(node)
elif isinstance(node, Type):
return self.generate_Type(node)

method = getattr(self, f'generate_{node.__class__.__name__}', None)
if method is None:
raise NotImplementedError(f"Generation not implemented for {node.__class__.__name__}")
return method(node)

def generate_dict(self, node: dict):
return f'(tuple {" ".join(f"({self.generate(k)} {self.generate(v)})" for k, v in node.items())})'

def generate_list(self, node: list):
return f'(list {" ".join(self.generate(item) for item in node)})'

def generate_Program(self, node: Program):
return '\n'.join(self.generate(stmt) for stmt in node.statements)

def generate_FunctionDeclaration(self, node: FunctionDeclaration):
decorators = ' '.join(node.decorators) or 'private'
is_public = any(d.name == '@public' for d in node.decorators)
func_type = 'public' if is_public else 'private'
params = ' '.join(self.generate(param) for param in node.parameters)
return_type = f' {self.generate(node.return_type)}' if node.return_type else ''
body = self.generate(node.body)
return f'(define-{decorators} ({node.name} {params}){return_type}\n{body})'
return f'(define-{func_type} ({node.name} {params})\n{self.indent()}{body})'

def generate_VariableDeclaration(self, node: VariableDeclaration):
type_str = self.generate(node.type) if node.type else ''
Expand All @@ -38,10 +64,9 @@ def generate_MapDeclaration(self, node: MapDeclaration):
value_type = self.generate(node.value_type)
return f'(define-map {node.name} {key_type} {value_type})'

def generate_AssetDeclaration(self, node):
# Here, `node.fields` should contain `Field` objects with `name` and `type` attributes.
fields = ' '.join(f'({self.generate(field.name)} {self.generate(field.type)})' for field in node.fields)
return f'(define-nft {self.generate(node.name)} ({fields}))'
def generate_AssetDeclaration(self, node: AssetDeclaration):
fields = ' '.join(f'({field.name} {self.generate(field.type)})' for field in node.fields)
return f'(define-non-fungible-token {node.name} {fields})'

def generate_TraitDeclaration(self, node: TraitDeclaration):
functions = '\n'.join(self.generate(func) for func in node.functions)
Expand Down Expand Up @@ -144,11 +169,20 @@ def generate_TupleLiteral(self, node: TupleLiteral):
elements = ' '.join(f'({k} {self.generate(v)})' for k, v in node.elements.items())
return f'(tuple {elements})'

def generate_str(self, node: str):
return node

def generate_OptionalLiteral(self, node: OptionalLiteral):
if node.value:
return f'(some {self.generate(node.value)})'
return 'none'

def generate_tuple(self, node):
if isinstance(node, tuple):
return ' '.join(self.generate(item) for item in node)
else:
raise NotImplementedError(f"Generation not implemented for {node.__class__.__name__}")

def generate_PrincipalLiteral(self, node: PrincipalLiteral):
return f"'{node.value}'"

Expand Down Expand Up @@ -176,11 +210,11 @@ def generate_ResponseType(self, node: ResponseType):
def generate_ContractCallExpression(self, node: ContractCallExpression):
contract = self.generate(node.contract)
args = ' '.join(self.generate(arg) for arg in node.arguments)
return f'(contract-call? {contract} {node.function} {args})'
return f'(contract-call? .{contract} {node.function} {args})'

def generate_AssetCallExpression(self, node: AssetCallExpression):
args = ' '.join(self.generate(arg) for arg in node.arguments)
return f'(nft-{node.function} {node.asset} {args})'
return f'(nft-{node.function}? {node.asset} {args})'

def generate_MapExpression(self, node: MapExpression):
list_expr = self.generate(node.list)
Expand All @@ -200,12 +234,12 @@ def generate_FoldExpression(self, node: FoldExpression):

def generate_ListComprehension(self, node: ListComprehension):
expression = self.generate(node.expression)
for_expr = self.generate(node.for_expr)
iterable = self.generate(node.iterable)
iterator = self.generate(node.iterator)
if node.condition:
condition = self.generate(node.condition)
return f'(map {expression} (filter {condition} {for_expr}))'
return f'(map {expression} {for_expr})'
return f'(map {expression} (filter (lambda ({iterator}) {condition}) {iterable}))'
return f'(map (lambda ({iterator}) {expression}) {iterable})'

def generate_ImportDeclaration(self, node: ImportDeclaration):
imports = ' '.join(node.imports)
Expand Down
15 changes: 9 additions & 6 deletions stxscript/transpiler.py
Original file line number Diff line number Diff line change
Expand Up @@ -68,12 +68,16 @@ def map_declaration(self, *args):

def asset_declaration(self, name, *fields):
field_nodes = []
for field in fields:
if isinstance(field, list) and len(field) == 2:
field_name, field_type = field
field_nodes.append(Field(name=field_name, type=field_type))
for i in range(0, len(fields), 2):
if i + 1 < len(fields):
field_name = fields[i]
field_type = fields[i + 1]
if isinstance(field_name, Identifier) and isinstance(field_type, Type):
field_nodes.append(Parameter(name=field_name.name, type=field_type))
else:
raise SyntaxError(f"Unexpected field format: {field_name} {field_type}")
else:
raise SyntaxError(f"Unexpected field format: {field}")
raise SyntaxError(f"Incomplete field definition for {fields[i]}")
return AssetDeclaration(name=name, fields=field_nodes)

def field(self, name, field_type):
Expand Down Expand Up @@ -259,7 +263,6 @@ def transpile(self, input_code):
try:
parse_tree = self.parser.parse(input_code)
ast = self.transformer.transform(parse_tree)
print(ast)
clarity_code = self.generator.generate(ast)
return clarity_code
except Exception as e:
Expand Down

0 comments on commit 489cd43

Please sign in to comment.