Revert "change pipeline to use object's UUID as key when possible"
This reverts commit a028b17.

Signed-off-by: Julio Faracco <[email protected]>
jcfaracco committed Nov 27, 2023
1 parent 9e753f7 commit 418cd45
Showing 1 changed file with 23 additions and 24 deletions.
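For context on the two keying strategies involved: the reverted commit keyed DAG nodes by a persistent UUID (via get_uuid()) whenever the object exposed one, while the code restored below derives the key from hash(obj) of the inspected callable. The following standalone sketch illustrates the difference; the Step class and its methods are hypothetical and not part of DASF.

import uuid


class Step:
    """Hypothetical pipeline step (not DASF code), used only to illustrate keys."""

    def __init__(self):
        self._uuid = str(uuid.uuid4())  # persistent identity, as in the reverted commit

    def get_uuid(self):
        return self._uuid

    def transform(self, data):
        return data


step = Step()

# Key used by the code this revert restores: the hash of the callable.
# For a bound method the hash is derived from the instance and the function,
# so it is stable within one interpreter session but will generally differ
# across runs.
hash_key = hash(step.transform)

# Key preferred by the reverted commit when get_uuid() was available: it
# survives across sessions because it is stored on the object itself.
uuid_key = step.get_uuid()

print(hash_key, uuid_key)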
47 changes: 23 additions & 24 deletions dasf/pipeline/pipeline.py
@@ -58,10 +58,11 @@ def execute_callbacks(self, func_name: str, *args, **kwargs):
         for callback in self._callbacks:
             getattr(callback, func_name)(*args, **kwargs)
 
-    def __add_into_dag(self, obj, func_name, key, parameters=None, itself=None):
+    def __add_into_dag(self, obj, func_name, parameters=None, itself=None):
+        key = hash(obj)
+
         if key not in self._dag_table:
             self._dag.add_node(key)
-            self._dag_g.node(str(key), func_name)
             self._dag_table[key] = dict()
             self._dag_table[key]["fn"] = obj
             self._dag_table[key]["name"] = func_name
@@ -77,12 +78,17 @@ def __add_into_dag(self, obj, func_name, key, parameters=None, itself=None):
             # If we are adding a object which require parameters,
             # we need to make sure they are mapped into DAG.
             for k, v in parameters.items():
-                dep_obj, dep_func_name, dep_key, _ = self.__inspect_element(v)
+                dep_obj, dep_func_name, _ = self.__inspect_element(v)
                 self.add(dep_obj)
+                if not self._dag.has_node(str(key)):
+                    self._dag_g.node(str(key), func_name)
+
+                if not self._dag.has_node(str(hash(dep_obj))):
+                    self._dag_g.node(str(hash(dep_obj)), dep_func_name)
 
-                self._dag.add_edge(dep_key, key)
+                self._dag.add_edge(hash(dep_obj), key)
 
-                self._dag_g.edge(str(dep_key), str(key), label=k)
+                self._dag_g.edge(str(hash(dep_obj)), str(key), label=k)
 
     def __inspect_element(self, obj):
         from dasf.datasets.base import Dataset
@@ -94,13 +100,11 @@ def generate_name(class_name, func_name):
         if inspect.isfunction(obj) and callable(obj):
             return (obj,
                     obj.__qualname__,
-                    hash(obj),
                     None)
         elif inspect.ismethod(obj):
             return (obj,
                     generate_name(obj.__self__.__class__.__name__,
                                   obj.__name__),
-                    obj.__self__.get_uuid() if hasattr(obj.__self__, "get_uuid") else hash(obj),
                     obj.__self__)
         elif issubclass(obj.__class__, Dataset) and hasattr(obj, "load"):
             # (Disabled) Register dataset for reusability
@@ -109,19 +113,16 @@ def generate_name(class_name, func_name):
             return (obj.load,
                     generate_name(obj.__class__.__name__,
                                   "load"),
-                    obj.get_uuid(),
                     obj)
         elif issubclass(obj.__class__, Fit) and hasattr(obj, "fit"):
             return (obj.fit,
                     generate_name(obj.__class__.__name__,
                                   "fit"),
-                    obj.get_uuid(),
                     obj)
         elif issubclass(obj.__class__, Transform) and hasattr(obj, "transform"):
             return (obj.transform,
                     generate_name(obj.__class__.__name__,
                                   "transform"),
-                    obj.get_uuid(),
                     obj)
         else:
             raise ValueError(
@@ -130,8 +131,8 @@ def generate_name(class_name, func_name):
             )
 
     def add(self, obj, **kwargs):
-        obj, func_name, uuid, objref = self.__inspect_element(obj)
-        self.__add_into_dag(obj, func_name, uuid, kwargs, objref)
+        obj, func_name, objref = self.__inspect_element(obj)
+        self.__add_into_dag(obj, func_name, kwargs, objref)
 
         return self
 
@@ -142,10 +143,6 @@ def visualize(self, filename=None):
             return self._dag_g
         return self._dag_g.view(filename)
 
-    def save_image(self, filename):
-        self._dag_g.render(outfile=filename, cleanup=True)
-
-
     def __register_dataset(self, dataset):
         key = str(hash(dataset.load))
         kwargs = {key: dataset}
@@ -161,8 +158,8 @@ def __execute(self, func, params, name):
         new_params = dict()
         if params:
             for k, v in params.items():
-                _, _, uuid, *_ = self.__inspect_element(v)
-                req_key = uuid
+                dep_obj, *_ = self.__inspect_element(v)
+                req_key = hash(dep_obj)
 
                 new_params[k] = self._dag_table[req_key]["ret"]
 
@@ -174,13 +171,15 @@ def __execute(self, func, params, name):
         return ret
 
     def get_result_from(self, obj):
-        _, obj_name, key, *_ = self.__inspect_element(obj)
+        _, obj_name, *_ = self.__inspect_element(obj)
 
+        for key in self._dag_table:
+            if self._dag_table[key]["name"] == obj_name:
+                if self._dag_table[key]["ret"] is None:
+                    raise Exception("Pipeline was not executed yet.")
+                return self._dag_table[key]["ret"]
+
-        if key in self._dag_table:
-            if self._dag_table[key]["ret"] is None:
-                raise Exception("Pipeline was not executed yet.")
-            return self._dag_table[key]["ret"]
-        raise Exception(f"Function {obj_name}-{key} was not added into pipeline.")
+        raise Exception(f"Function {obj_name} was not added into pipeline.")
 
     def run(self):
         if not nx.is_directed_acyclic_graph(self._dag):
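The final hunk also restores the lookup strategy in get_result_from(): instead of re-deriving a key from the argument and indexing _dag_table with it, the restored code scans the table for an entry whose registered name matches. A standalone sketch of that behaviour follows; the table contents and step names are made up for illustration and are not taken from this commit.

def get_result_from(dag_table, obj_name):
    # Mirror of the restored lookup: match on the stored name, not on the key.
    for key in dag_table:
        if dag_table[key]["name"] == obj_name:
            if dag_table[key]["ret"] is None:
                raise RuntimeError("Pipeline was not executed yet.")
            return dag_table[key]["ret"]
    raise RuntimeError(f"Function {obj_name} was not added into pipeline.")


# Hypothetical table contents; in the real pipeline the keys are hash(obj) values.
table = {
    111: {"name": "MyDataset.load", "ret": [1, 2, 3]},
    222: {"name": "Normalize.transform", "ret": None},
}

print(get_result_from(table, "MyDataset.load"))   # -> [1, 2, 3]
# get_result_from(table, "Normalize.transform")   # raises: pipeline not executed yet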
