diff --git a/stxscript/ast_nodes.py b/stxscript/ast_nodes.py index 54695eb..b74bad3 100644 --- a/stxscript/ast_nodes.py +++ b/stxscript/ast_nodes.py @@ -212,7 +212,7 @@ class ListComprehension(Expression): expression: Expression iterable: Expression iterator: Identifier - condition: Optional[Expression] + condition: Optional[Expression] = None @dataclass class ContractCallExpression(Expression): diff --git a/stxscript/clarity_generator.py b/stxscript/clarity_generator.py index 4d3b055..53f7cc5 100644 --- a/stxscript/clarity_generator.py +++ b/stxscript/clarity_generator.py @@ -8,21 +8,47 @@ def indent(self): return " " * self.indent_level def generate(self, node): - # Debug print to check the type of node being processed + if node is None: + return "" + elif isinstance(node, str): + return self.generate_str(node) + elif isinstance(node, int): + return str(node) + elif isinstance(node, float): + return str(node) + elif isinstance(node, bool): + return 'true' if node else 'false' + elif isinstance(node, list): + return self.generate_list(node) + elif isinstance(node, tuple): + return self.generate_tuple(node) + elif isinstance(node, dict): + return self.generate_dict(node) + elif isinstance(node, Identifier): + return self.generate_Identifier(node) + elif isinstance(node, Type): + return self.generate_Type(node) + method = getattr(self, f'generate_{node.__class__.__name__}', None) if method is None: raise NotImplementedError(f"Generation not implemented for {node.__class__.__name__}") return method(node) + def generate_dict(self, node: dict): + return f'(tuple {" ".join(f"({self.generate(k)} {self.generate(v)})" for k, v in node.items())})' + + def generate_list(self, node: list): + return f'(list {" ".join(self.generate(item) for item in node)})' + def generate_Program(self, node: Program): return '\n'.join(self.generate(stmt) for stmt in node.statements) def generate_FunctionDeclaration(self, node: FunctionDeclaration): - decorators = ' '.join(node.decorators) or 'private' + is_public = any(d.name == '@public' for d in node.decorators) + func_type = 'public' if is_public else 'private' params = ' '.join(self.generate(param) for param in node.parameters) - return_type = f' {self.generate(node.return_type)}' if node.return_type else '' body = self.generate(node.body) - return f'(define-{decorators} ({node.name} {params}){return_type}\n{body})' + return f'(define-{func_type} ({node.name} {params})\n{self.indent()}{body})' def generate_VariableDeclaration(self, node: VariableDeclaration): type_str = self.generate(node.type) if node.type else '' @@ -38,10 +64,9 @@ def generate_MapDeclaration(self, node: MapDeclaration): value_type = self.generate(node.value_type) return f'(define-map {node.name} {key_type} {value_type})' - def generate_AssetDeclaration(self, node): - # Here, `node.fields` should contain `Field` objects with `name` and `type` attributes. - fields = ' '.join(f'({self.generate(field.name)} {self.generate(field.type)})' for field in node.fields) - return f'(define-nft {self.generate(node.name)} ({fields}))' + def generate_AssetDeclaration(self, node: AssetDeclaration): + fields = ' '.join(f'({field.name} {self.generate(field.type)})' for field in node.fields) + return f'(define-non-fungible-token {node.name} {fields})' def generate_TraitDeclaration(self, node: TraitDeclaration): functions = '\n'.join(self.generate(func) for func in node.functions) @@ -144,11 +169,20 @@ def generate_TupleLiteral(self, node: TupleLiteral): elements = ' '.join(f'({k} {self.generate(v)})' for k, v in node.elements.items()) return f'(tuple {elements})' + def generate_str(self, node: str): + return node + def generate_OptionalLiteral(self, node: OptionalLiteral): if node.value: return f'(some {self.generate(node.value)})' return 'none' + def generate_tuple(self, node): + if isinstance(node, tuple): + return ' '.join(self.generate(item) for item in node) + else: + raise NotImplementedError(f"Generation not implemented for {node.__class__.__name__}") + def generate_PrincipalLiteral(self, node: PrincipalLiteral): return f"'{node.value}'" @@ -176,11 +210,11 @@ def generate_ResponseType(self, node: ResponseType): def generate_ContractCallExpression(self, node: ContractCallExpression): contract = self.generate(node.contract) args = ' '.join(self.generate(arg) for arg in node.arguments) - return f'(contract-call? {contract} {node.function} {args})' + return f'(contract-call? .{contract} {node.function} {args})' def generate_AssetCallExpression(self, node: AssetCallExpression): args = ' '.join(self.generate(arg) for arg in node.arguments) - return f'(nft-{node.function} {node.asset} {args})' + return f'(nft-{node.function}? {node.asset} {args})' def generate_MapExpression(self, node: MapExpression): list_expr = self.generate(node.list) @@ -200,12 +234,12 @@ def generate_FoldExpression(self, node: FoldExpression): def generate_ListComprehension(self, node: ListComprehension): expression = self.generate(node.expression) - for_expr = self.generate(node.for_expr) + iterable = self.generate(node.iterable) iterator = self.generate(node.iterator) if node.condition: condition = self.generate(node.condition) - return f'(map {expression} (filter {condition} {for_expr}))' - return f'(map {expression} {for_expr})' + return f'(map {expression} (filter (lambda ({iterator}) {condition}) {iterable}))' + return f'(map (lambda ({iterator}) {expression}) {iterable})' def generate_ImportDeclaration(self, node: ImportDeclaration): imports = ' '.join(node.imports) diff --git a/stxscript/transpiler.py b/stxscript/transpiler.py index 28cf1ae..7e7c79b 100644 --- a/stxscript/transpiler.py +++ b/stxscript/transpiler.py @@ -68,12 +68,16 @@ def map_declaration(self, *args): def asset_declaration(self, name, *fields): field_nodes = [] - for field in fields: - if isinstance(field, list) and len(field) == 2: - field_name, field_type = field - field_nodes.append(Field(name=field_name, type=field_type)) + for i in range(0, len(fields), 2): + if i + 1 < len(fields): + field_name = fields[i] + field_type = fields[i + 1] + if isinstance(field_name, Identifier) and isinstance(field_type, Type): + field_nodes.append(Parameter(name=field_name.name, type=field_type)) + else: + raise SyntaxError(f"Unexpected field format: {field_name} {field_type}") else: - raise SyntaxError(f"Unexpected field format: {field}") + raise SyntaxError(f"Incomplete field definition for {fields[i]}") return AssetDeclaration(name=name, fields=field_nodes) def field(self, name, field_type): @@ -259,7 +263,6 @@ def transpile(self, input_code): try: parse_tree = self.parser.parse(input_code) ast = self.transformer.transform(parse_tree) - print(ast) clarity_code = self.generator.generate(ast) return clarity_code except Exception as e: