Skip to content

Commit

Permalink
add option to disable globals tracking
Browse files Browse the repository at this point in the history
  • Loading branch information
amakelov committed Aug 21, 2024
1 parent 71a652e commit 7267001
Show file tree
Hide file tree
Showing 4 changed files with 22 additions and 6 deletions.
14 changes: 10 additions & 4 deletions mandala/deps/tracers/dec_impl.py
Original file line number Diff line number Diff line change
Expand Up @@ -71,7 +71,7 @@ def __getitem__(self, __key: str) -> Any:
f"Function/class {result} from module {unwrapped_result.__module__} is accessed but not tracked"
)
elif is_global_val(result):
TracerState.tracer.register_global_access(key=__key, value=result)
tracer.register_global_access(key=__key, value=result)
else:
if (
DecTracerConfig.restrict_global_accesses
Expand Down Expand Up @@ -166,6 +166,7 @@ def __init__(
paths: List[Path],
graph: Optional[DependencyGraph] = None,
strict: bool = True,
track_globals: bool = True,
allow_methods: bool = False,
skip_unhashable_globals: bool = True,
skip_globals_silently: bool = False,
Expand All @@ -174,8 +175,11 @@ def __init__(
self.graph = DependencyGraph() if graph is None else graph
self.paths = paths
self.strict = strict

self.track_globals = track_globals
self.skip_unhashable_globals = skip_unhashable_globals
self.skip_globals_silently = skip_globals_silently

self.allow_methods = allow_methods

self._traced = {}
Expand Down Expand Up @@ -240,15 +244,17 @@ def register_call(self, func: Callable) -> CallableNode:
assert parent is not None
self.graph.add_edge(parent, node)
### get globals
global_nodes = self.get_globals(func=func)
for global_node in global_nodes:
self.graph.add_edge(node, global_node)
# global_nodes = self.get_globals(func=func)
# for global_node in global_nodes:
# self.graph.add_edge(node, global_node)
if len(self.call_stack) == 1:
# this is the root of the graph
self.graph.roots.add(node.key)
return node

def register_global_access(self, key: str, value: Any):
if not self.track_globals:
return
assert len(self.call_stack) > 0
calling_node = self.call_stack[-1]
node = GlobalVarNode.from_obj(
Expand Down
2 changes: 2 additions & 0 deletions mandala/deps/tracers/tracer_base.py
Original file line number Diff line number Diff line change
Expand Up @@ -11,12 +11,14 @@ def __init__(
paths: List[Path],
strict: bool = True,
allow_methods: bool = False,
track_globals: bool = True,
):
self.call_stack: List[Optional[CallableNode]] = []
self.graph = DependencyGraph()
self.paths = paths
self.strict = strict
self.allow_methods = allow_methods
self.track_globals = track_globals

@abstractmethod
def __enter__(self):
Expand Down
10 changes: 8 additions & 2 deletions mandala/deps/versioner.py
Original file line number Diff line number Diff line change
Expand Up @@ -53,21 +53,26 @@ def __init__(
paths: List[Path],
TracerCls: type,
strict: bool,
track_globals: bool,
skip_unhashable_globals: bool,
skip_globals_silently: bool,
skip_missing_deps: bool,
skip_missing_silently: bool,
skip_globals_silently: bool,
track_methods: bool,
package_name: Optional[str] = None,
):
assert len(paths) in [0, 1]
self.paths = paths
self.TracerCls = TracerCls
self.strict = strict
self.skip_unhashable_globals = skip_unhashable_globals

self.skip_missing_deps = skip_missing_deps
self.skip_missing_silently = skip_missing_silently

self.track_globals = track_globals
self.skip_unhashable_globals = skip_unhashable_globals
self.skip_globals_silently = skip_globals_silently

self.allow_methods = track_methods
self.package_name = package_name
self.global_topology: DependencyGraph = DependencyGraph()
Expand Down Expand Up @@ -124,6 +129,7 @@ def make_tracer(self) -> TracerABC:
paths=[Config.mandala_path] + self.paths,
strict=self.strict,
allow_methods=self.allow_methods,
track_globals=self.track_globals,
skip_unhashable_globals=self.skip_unhashable_globals,
skip_globals_silently=self.skip_globals_silently,
)
Expand Down
2 changes: 2 additions & 0 deletions mandala/storage.py
Original file line number Diff line number Diff line change
Expand Up @@ -32,6 +32,7 @@ def __init__(self, db_path: str = ":memory:",
skip_missing_deps: bool = True, # whether to allow dependencies that were not found
skip_missing_silently: bool = False, # whether to skip such dependencies silently
deps_package: Optional[str] = None,
track_globals: bool = True,
):
self.db = DBAdapter(db_path=db_path)

Expand Down Expand Up @@ -82,6 +83,7 @@ def __init__(self, db_path: str = ":memory:",
skip_globals_silently=skip_globals_silently,
skip_missing_deps=skip_missing_deps,
skip_missing_silently=skip_missing_silently,
track_globals=track_globals,
)
self.sources["versioner"] = versioner
else:
Expand Down

0 comments on commit 7267001

Please sign in to comment.