-
Notifications
You must be signed in to change notification settings - Fork 10
Commit
This commit does not belong to any branch on this repository, and may belong to a fork outside of the repository.
Added code parser and tests; currently simply printing the output of the code parser (#11)
Co-authored-by: Red-Giuliano <[email protected]>
- Loading branch information
1 parent
a8ded80
commit 42a03fc
Showing
4 changed files
with
181 additions
and
1 deletion.
There are no files selected for viewing
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -1,4 +1,6 @@ | ||
fastapi | ||
uvicorn | ||
pydantic | ||
toml | ||
toml | ||
pytest | ||
astroid |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,93 @@ | ||
import astroid | ||
from typing import List, Dict, Union, Tuple, Any | ||
|
||
def get_imports(module) -> List[str]:
    """Return the names bound by every import statement found under ``module``.

    Every ``(name, alias)`` pair of each statement is considered, so
    ``import os, sys`` yields both ``os`` and ``sys``.

    Args:
        module: An astroid node providing ``nodes_of_class``.

    Returns:
        Names bound by ``from ... import ...`` statements first, followed by
        names bound by plain ``import ...`` statements. An alias is used
        whenever one is given.
    """
    import_froms = [
        alias or name
        for node in module.nodes_of_class(astroid.ImportFrom)
        for name, alias in node.names
    ]
    # ``import a.b.c`` without an alias binds only the top-level package ``a``.
    imports = [
        alias or name.split('.')[0]
        for node in module.nodes_of_class(astroid.Import)
        for name, alias in node.names
    ]
    return import_froms + imports
|
||
def get_functions(module) -> Tuple[List[str], List[str]]:
    """Collect function names and parameter names for all functions in ``module``.

    Args:
        module: An astroid node providing ``nodes_of_class``.

    Returns:
        A ``(function_names, argument_names)`` pair. ``argument_names`` is a
        flat list covering every function found; it includes positional-only,
        regular and keyword-only parameters as well as ``*args`` / ``**kwargs``
        names.
    """
    function_names = []
    argument_names = []
    for function_def in module.nodes_of_class(astroid.FunctionDef):
        function_names.append(function_def.name)
        arguments = function_def.args
        # These attributes hold lists of name nodes (or None when absent).
        parameter_groups = (
            getattr(arguments, 'posonlyargs', None),
            arguments.args,
            getattr(arguments, 'kwonlyargs', None),
        )
        for group in parameter_groups:
            argument_names.extend(arg.name for arg in group or [])
        # ``vararg`` / ``kwarg`` are plain strings (or None), not nodes.
        for star_name in (getattr(arguments, 'vararg', None), getattr(arguments, 'kwarg', None)):
            if star_name:
                argument_names.append(star_name)
    return function_names, argument_names
|
||
def get_defined_names(module) -> List[str]:
    """Return the names bound by assignment statements in ``module``.

    Names nested inside tuple, list or starred targets (``a, b = 1, 2``) are
    included. Names that are also a regular parameter of any function in
    ``module`` are left out.

    Args:
        module: An astroid node providing ``nodes_of_class``.

    Returns:
        Unique names in order of first assignment.
    """
    func_def_names = {
        arg.name
        for func in module.nodes_of_class(astroid.FunctionDef)
        for arg in func.args.args
    }
    # A dict is used as an ordered set so the result is deterministic.
    defined_names = {}
    for defnode in module.nodes_of_class(astroid.Assign):
        for target in defnode.targets:
            # nodes_of_class also yields ``target`` itself when it is a plain
            # AssignName, so simple and unpacking targets share one path.
            for name_node in target.nodes_of_class(astroid.AssignName):
                if name_node.name not in func_def_names:
                    defined_names[name_node.name] = None
    return list(defined_names)
|
||
def get_loaded_modules(module) -> List[str]:
    """Return the base name of every attribute access in ``module``.

    For each ``Attribute`` node whose ``expr`` has a ``name`` (``math`` in
    ``math.sqrt``), that name is collected. Any exception raised during the
    scan is printed and an empty list is returned instead.
    """
    try:
        base_names = []
        for attribute in module.nodes_of_class(astroid.Attribute):
            base = attribute.expr
            if hasattr(base, 'name'):
                base_names.append(base.name)
    except Exception as e:
        print(f"Error occurred with modules: {e}")
        return []
    return base_names
|
||
def get_loaded_names(module, defined_names) -> List[str]:
    """Return names read in ``module`` that it neither defines nor takes as parameters.

    Args:
        module: An astroid node providing ``nodes_of_class``.
        defined_names: Names considered defined by this module.

    Returns:
        The ``name`` of every ``Name`` node that is not in ``defined_names``
        and is not a function argument name; duplicates are kept.
    """
    _, function_arguments = get_functions(module)
    # Build the exclusion set once rather than concatenating and scanning a
    # list for every Name node.
    excluded = set(defined_names) | set(function_arguments)
    return [
        usenode.name
        for usenode in module.nodes_of_class(astroid.Name)
        if usenode.name not in excluded
    ]
|
||
def parse_cells(cells: List[str]) -> Dict[int, Dict[str, Union[str, List[str]]]]:
    """Parse each code cell and record the names it defines and the names it loads.

    Args:
        cells: Source code of each cell, in order.

    Returns:
        A dict keyed by cell position. Each value holds:
          * ``'code'``: the cell source, unchanged;
          * ``'defined_names'``: assigned names, imported names, function names
            and class names, in that order;
          * ``'loaded_names'``: de-duplicated names the cell reads but does not
            define, including the base names of attribute accesses.
    """
    cell_dict = {}
    for i, cell in enumerate(cells):
        module = astroid.parse(cell)
        function_names, _ = get_functions(module)
        # A class statement binds its name just like a def does, so later
        # cells that use the class must be able to depend on this one.
        class_names = [class_def.name for class_def in module.nodes_of_class(astroid.ClassDef)]
        defined_names = get_defined_names(module) + get_imports(module) + function_names + class_names
        loaded_names = get_loaded_names(module, defined_names) + get_loaded_modules(module)
        cell_dict[i] = {
            'code': cell,
            'defined_names': defined_names,
            'loaded_names': list(set(loaded_names)),  # Remove duplicates
        }
    return cell_dict
|
||
|
||
def build_dependency_graph(cell_dict: Dict[int, Dict[str, Any]]) -> Dict[int, Dict[str, Any]]:
    """
    Turn a cell dictionary into a dependency graph.

    Every cell is a node; the child/parent links are filled in by
    ``add_child_cells`` from the cells' defined and loaded names. The call is
    made with an empty "previous components" mapping.
    """
    # Nothing from an earlier graph is passed along here.
    no_previous_components = {}
    return add_child_cells(cell_dict, no_previous_components)
|
||
def find_child_cells(cell: Dict[str, Any], code_dictionary: Dict[str, Any], idx: int) -> List[str]:
    """Return the keys of later cells that load a name defined by ``cell``.

    Only the entries of ``code_dictionary`` positioned after ``idx`` (in
    insertion order) are examined. A later cell qualifies when its
    ``'loaded_names'`` shares at least one name with ``cell['defined_names']``.
    """
    defined = cell['defined_names']
    return [
        later_key
        for later_key in list(code_dictionary)[idx + 1:]
        if set(defined) & set(code_dictionary[later_key]['loaded_names'])
    ]
|
||
def add_parent_cells(code_dictionary: Dict[str, Any]) -> Dict[str, Any]:
    """Fill in ``'parent_cells'`` for every cell from the ``'child_cells'`` lists.

    Each cell's ``'parent_cells'`` is rebuilt from scratch, so calling this
    function repeatedly on the same dictionary does not accumulate duplicate
    parents, and cells without parents get an empty list instead of no key.
    A cell lacking ``'child_cells'`` is given an empty list.

    Args:
        code_dictionary: Mapping of cell key to cell dict; modified in place.

    Returns:
        The same ``code_dictionary`` object.

    Raises:
        KeyError: If a ``'child_cells'`` entry names a key that is not present
            in ``code_dictionary``.
    """
    # Reset first so stale links from an earlier call never survive.
    for cell in code_dictionary.values():
        cell['parent_cells'] = []
    for key, cell in code_dictionary.items():
        child_cells = cell.get('child_cells', [])
        for child_cell in child_cells:
            code_dictionary[child_cell]['parent_cells'].append(key)
        cell['child_cells'] = child_cells
    return code_dictionary
|
||
def add_child_cells(code_dictionary: Dict[str, Any], prev_components: Dict[str, Any]) -> Dict[str, Any]:
    """Attach child links to every cell, then derive the parent links.

    For each cell, ``'child_cells'`` is computed with ``find_child_cells`` and
    ``'previous_child_cells'`` is copied from the ``'child_cells'`` entry that
    ``prev_components`` holds under the same key (an empty list when there is
    none). The result of ``add_parent_cells`` on the updated dictionary is
    returned.
    """
    for position, (key, cell) in enumerate(list(code_dictionary.items())):
        cell['child_cells'] = find_child_cells(cell, code_dictionary, position)
        earlier_entry = prev_components.get(key, {})
        cell['previous_child_cells'] = earlier_entry.get('child_cells', [])
    return add_parent_cells(code_dictionary)
|
||
|
||
|
||
def print_astroid_tree(code):
    """Parse ``code`` with astroid and print the tree representation of the result."""
    tree_text = astroid.parse(code).repr_tree()
    print(tree_text)
|
||
|
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,80 @@ | ||
from zt_backend.runner.code_cell_parser import parse_cells, build_dependency_graph | ||
import pytest | ||
|
||
# Test 1: No dependencies between cells | ||
def test_no_dependencies():
    """Cells that only assign literals each define one name and load none."""
    graph = build_dependency_graph(parse_cells(["a = 1", "b = 2", "c = 3"]))
    for index, name in enumerate(['a', 'b', 'c']):
        assert graph[index]['defined_names'] == [name], "Test 1 Failed"
        assert graph[index]['loaded_names'] == [], "Test 1 Failed"
|
||
# Test 2: Simple dependencies | ||
def test_simple_dependencies():
    """Each cell loads exactly the name defined by the cell before it."""
    graph = build_dependency_graph(parse_cells(["a = 1", "b = a + 2", "c = b + 3"]))
    expected = [(['a'], []), (['b'], ['a']), (['c'], ['b'])]
    for index, (defined, loaded) in enumerate(expected):
        assert graph[index]['defined_names'] == defined, "Test 2 Failed"
        assert graph[index]['loaded_names'] == loaded, "Test 2 Failed"
|
||
# Test 3: Complex dependencies | ||
def test_complex_dependencies():
    """The last cell loads names defined by two different earlier cells."""
    graph = build_dependency_graph(parse_cells(["a = 1", "b = a + 2", "c = b + a"]))
    for index, (defined, loaded) in enumerate([(['a'], []), (['b'], ['a'])]):
        assert graph[index]['defined_names'] == defined, "Test 3 Failed"
        assert graph[index]['loaded_names'] == loaded, "Test 3 Failed"
    assert graph[2]['defined_names'] == ['c'], "Test 3 Failed"
    # Order of loaded names is not guaranteed, so compare as sets.
    assert set(graph[2]['loaded_names']) == {'a', 'b'}, "Test 3 Failed"
|
||
# Test 4: Overriding dependencies | ||
def test_overriding_dependencies():
    """A name redefined in a later cell is still reported as defined there."""
    sources = ["a = 1", "b = a + 2", "a = 3", "c = a + b"]
    graph = build_dependency_graph(parse_cells(sources))
    expected = [(['a'], []), (['b'], ['a']), (['a'], [])]
    for index, (defined, loaded) in enumerate(expected):
        assert graph[index]['defined_names'] == defined, "Test 4 Failed"
        assert graph[index]['loaded_names'] == loaded, "Test 4 Failed"
    assert graph[3]['defined_names'] == ['c'], "Test 4 Failed"
    # Order of loaded names is not guaranteed, so compare as sets.
    assert set(graph[3]['loaded_names']) == {'a', 'b'}, "Test 4 Failed"
|
||
# Test 5: Function definitions and calls | ||
def test_function_definitions_and_calls():
    """A function defined in one cell is loaded by the cell that calls it.

    All three cells are checked; the expectation for the third cell is the
    same one asserted for the identical cell in Test 6.
    """
    cells = ["def f(x): return x + 2", "a = f(2)", "b = a + 3"]
    cell_dict = build_dependency_graph(parse_cells(cells))
    assert cell_dict[0]['defined_names'] == ['f'] and cell_dict[0]['loaded_names'] == [], "Test 5 Failed"
    assert cell_dict[1]['defined_names'] == ['a'] and cell_dict[1]['loaded_names'] == ['f'], "Test 5 Failed"
    assert cell_dict[2]['defined_names'] == ['b'] and cell_dict[2]['loaded_names'] == ['a'], "Test 5 Failed"
|
||
# Test 6: Dependencies with function calls and redefinitions | ||
def test_dependencies_with_function_calls_and_redefinitions():
    """Two functions defined in separate cells are each loaded by their caller."""
    sources = [
        "def f(x): return x + 2",
        "a = f(2)",
        "b = a + 3",
        "def g(x): return x * 2",
        "c = g(b)",
    ]
    graph = build_dependency_graph(parse_cells(sources))
    expected = [(['f'], []), (['a'], ['f']), (['b'], ['a']), (['g'], [])]
    for index, (defined, loaded) in enumerate(expected):
        assert graph[index]['defined_names'] == defined, "Test 6 Failed"
        assert graph[index]['loaded_names'] == loaded, "Test 6 Failed"
    assert graph[4]['defined_names'] == ['c'], "Test 6 Failed"
    # Order of loaded names is not guaranteed, so compare as sets.
    assert set(graph[4]['loaded_names']) == {'b', 'g'}, "Test 6 Failed"
|
||
# Test 7: Importing a module in one cell and using it in another | ||
def test_importing_module():
    """An imported module counts as defined and is loaded by attribute access."""
    graph = build_dependency_graph(parse_cells(["import math", "a = math.sqrt(4)"]))
    expected = [(['math'], []), (['a'], ['math'])]
    for index, (defined, loaded) in enumerate(expected):
        assert graph[index]['defined_names'] == defined, "Test 7 Failed"
        assert graph[index]['loaded_names'] == loaded, "Test 7 Failed"
|
||
# Test 8: Defining multiple variables in one cell | ||
def test_multiple_variables_in_one_cell():
    """Several semicolon-separated assignments in one cell are all picked up."""
    graph = build_dependency_graph(parse_cells(["a = 1; b = 2; c = 3", "d = a + b + c"]))
    all_three = {'a', 'b', 'c'}
    assert set(graph[0]['defined_names']) == all_three, "Test 8 Failed"
    assert graph[0]['loaded_names'] == [], "Test 8 Failed"
    assert graph[1]['defined_names'] == ['d'], "Test 8 Failed"
    assert set(graph[1]['loaded_names']) == all_three, "Test 8 Failed"
|
||
# Test 9: Multiline cells | ||
def test_multiline_cells():
    """Cells spanning several lines report every assignment and every load."""
    graph = build_dependency_graph(parse_cells(["a = 1\nb = 2\nc = 3", "d = a + b\ne = c * d"]))
    assert set(graph[0]['defined_names']) == {'a', 'b', 'c'}, "Test 9 Failed"
    assert graph[0]['loaded_names'] == [], "Test 9 Failed"
    assert set(graph[1]['defined_names']) == {'d', 'e'}, "Test 9 Failed"
    # ``d`` is defined in the same cell, so it is not expected among the loads.
    assert set(graph[1]['loaded_names']) == {'a', 'b', 'c'}, "Test 9 Failed"
|
||
# Test 10: Importing a module with an alias in one cell and using it in another | ||
def test_importing_module_with_alias():
    """An aliased import is tracked under its alias, not the module name."""
    graph = build_dependency_graph(parse_cells(["import math as m", "a = m.sqrt(4)"]))
    expected = [(['m'], []), (['a'], ['m'])]
    for index, (defined, loaded) in enumerate(expected):
        assert graph[index]['defined_names'] == defined, "Test 10 Failed"
        assert graph[index]['loaded_names'] == loaded, "Test 10 Failed"