diff --git a/train/comms/pt/commsTraceParser.py b/train/comms/pt/commsTraceParser.py
index 31cbfad4..02d2b753 100644
--- a/train/comms/pt/commsTraceParser.py
+++ b/train/comms/pt/commsTraceParser.py
@@ -217,8 +217,8 @@ def _parseExecutionTrace(
     """
 
     # Execution Trace PG_ID types availability
-    ET_PG_NAME_TUPLE = True if in_trace.schema == "1.0.3-chakra.0.0.4" else False
-    ET_BACKENDID = True if in_trace.schema != "1.0.3-chakra.0.0.4" else False
+    ET_PG_NAME_TUPLE = in_trace.schema_pytorch() >= (1, 0, 3)
+    ET_BACKENDID = in_trace.schema_pytorch() < (1, 0, 3)
 
     initOps = []
     newCommsTrace = []
diff --git a/train/compute/python/test/data/1.0.3-chakra.0.0.4/resnet_1gpu_et.json.gz b/train/compute/python/test/data/1.0.3-chakra.0.0.4/resnet_1gpu_et.json.gz
new file mode 100644
index 00000000..87b3e640
Binary files /dev/null and b/train/compute/python/test/data/1.0.3-chakra.0.0.4/resnet_1gpu_et.json.gz differ
diff --git a/train/compute/python/test/test_execution_trace.py b/train/compute/python/test/test_execution_trace.py
new file mode 100644
index 00000000..e96ae595
--- /dev/null
+++ b/train/compute/python/test/test_execution_trace.py
@@ -0,0 +1,38 @@
+import gzip
+import json
+import os
+import unittest
+
+from param_bench.train.compute.python.tools.execution_trace import ExecutionTrace
+from param_bench.train.compute.python.tools.validate_trace import TraceValidator
+
+CURR_DIR = os.path.dirname(os.path.realpath(__file__))
+
+
+class TestTraceLoadAndValidate(unittest.TestCase):
+    def setUp(self):
+        self.trace_base = os.path.join(CURR_DIR, "data")
+
+    def _test_and_validate_trace(self, trace_file):
+        with (
+            gzip.open(trace_file, "rb")
+            if trace_file.endswith("gz")
+            else open(trace_file, "r")
+        ) as execution_data:
+            execution_trace: ExecutionTrace = ExecutionTrace(json.load(execution_data))
+            t = TraceValidator(execution_trace)
+            self.assertTrue(t.validate())
+            return t, execution_trace
+
+    def test_trace_load_resnet_1gpu(self):
+        et_file = os.path.join(
+            self.trace_base, "1.0.3-chakra.0.0.4/resnet_1gpu_et.json.gz"
+        )
+        t, et = self._test_and_validate_trace(et_file)
+        self.assertGreater(t.num_ops(), 1000)
+        self.assertEqual(t.num_comm_ops(), 12)
+        self.assertEqual(t.num_triton_ops(), 0)
+
+
+if __name__ == "__main__":
+    unittest.main()
diff --git a/train/compute/python/tools/execution_trace.py b/train/compute/python/tools/execution_trace.py
index f7553ec5..00e23a75 100644
--- a/train/compute/python/tools/execution_trace.py
+++ b/train/compute/python/tools/execution_trace.py
@@ -350,6 +350,17 @@ def __init__(self, json):
         # remove all dataloader ops
         self.remove_dataloader_ops()
 
+    def _versiontuple(self, v: str) -> Tuple[int, ...]:
+        return tuple(map(int, (v.split("."))))
+
+    def schema_pytorch(self) -> Tuple[int, ...]:
+        return self._versiontuple(self.schema.split("-")[0])
+
+    def schema_chakra(self) -> Tuple[int, ...]:
+        if "-" not in self.schema:
+            return (0, 0, 0)
+        return self._versiontuple(self.schema.split("-")[1])
+
     @staticmethod
     def _read_attrs(node: Dict[str, Any]) -> Tuple:
         attr_types = {
diff --git a/train/compute/python/tools/validate_trace.py b/train/compute/python/tools/validate_trace.py
new file mode 100644
index 00000000..1b6f89ba
--- /dev/null
+++ b/train/compute/python/tools/validate_trace.py
@@ -0,0 +1,113 @@
+from __future__ import (
+    absolute_import,
+    annotations,
+    division,
+    print_function,
+    unicode_literals,
+)
+
+import gzip
+import json
+
+from .execution_trace import ExecutionTrace
+
+
+class TraceValidator:
+
+    def __init__(self, execution_trace: ExecutionTrace):
+        self.et = execution_trace
+
+    def _ops(self):
+        return (n for n in self.et.nodes.values() if n.is_op())
+
+    def _validate_ops(self) -> bool:
+        """Make sure the pytorch operators are valid"""
+        ops = self._ops()
+        for op in ops:
+            if op.name == "":
+                print(f"op should have valid name, node id = {op.id}")
+
+            # if len(list(op.get_outputs())) + len(list(op.get_inputs())) == 0:
+            #     print(f"op should have outputs or inputs, node = {op.name}")
+            # FIXME see "autograd::engine::evaluate_function: DivBackward1"
+            # currently let's skip this
+            # return False
+        return True
+
+    def _validate_tree(self) -> bool:
+        """TBD validate that the generated data structure is a tree
+        with parent/child relationship. We can use pydot or networkx libs for this
+        """
+        return True
+
+    def _validate_param_comms(self) -> bool:
+        """Check if param comms has correct attributes"""
+        # This should use the comms parser, for now something simple will be fine
+        # https://github.com/facebookresearch/param/blob/main/train/comms/pt/commsTraceParser.py#L256
+
+        if self.et.schema_pytorch() < (1, 0, 2):
+            return True
+
+        def check_comms_node(n) -> bool:
+            """TODO use comms parser"""
+            has_pg_id = False
+            # Slightly hacky, but find an argument with tuple type
+            for arg in n.get_inputs():
+                if arg[0] == "Tuple[String,String]":
+                    print(f" {n.name}, process group args = {arg}")
+                    has_pg_id = True
+            return has_pg_id
+
+        return all(
+            check_comms_node(n)
+            for n in self.et.nodes.values()
+            if n.is_op() and n.name == "record_param_comms"
+        )
+
+    def _validate_triton(self) -> bool:
+        """Make sure triton kernels have correct values
+        TODO update for checking if kernel files are captured.
+        """
+        return True
+
+    def validate(self) -> bool:
+        return all(
+            [
+                self._validate_ops(),
+                self._validate_tree(),
+                self._validate_param_comms(),
+                self._validate_triton(),
+            ]
+        )
+
+    def num_ops(self) -> int:
+        return len(list(self._ops()))
+
+    def num_comm_ops(self) -> int:
+        return sum(1 for op in self._ops() if op.name == "record_param_comms")
+
+    def num_triton_ops(self) -> int:
+        return sum(1 for op in self._ops() if "triton" in op.name)
+
+
+def main():
+    import sys
+
+    execution_json = sys.argv[1]
+
+    with (
+        gzip.open(execution_json, "rb")
+        if execution_json.endswith("gz")
+        else open(execution_json, "r")
+    ) as execution_data:
+        execution_trace: ExecutionTrace = ExecutionTrace(json.load(execution_data))
+        t = TraceValidator(execution_trace)
+        print(
+            f"num ops = {t.num_ops()}, num comms = {t.num_comm_ops()}, "
+            f"num triton ops = {t.num_triton_ops()}"
+        )
+        print("Trace validation result = ", t.validate())
+
+
+if __name__ == "__main__":
+    main()
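
For anyone who wants to exercise the new validator outside the unit test, here is a minimal sketch that mirrors the `main()` entry point in `validate_trace.py` and the new test. The trace path is a placeholder (for example, a local copy of the bundled `resnet_1gpu_et.json.gz`); this is not part of the diff itself.

```python
# Hypothetical driver for the new TraceValidator; trace_path is a placeholder.
import gzip
import json

from param_bench.train.compute.python.tools.execution_trace import ExecutionTrace
from param_bench.train.compute.python.tools.validate_trace import TraceValidator

trace_path = "resnet_1gpu_et.json.gz"  # placeholder: any gzipped execution trace

with gzip.open(trace_path, "rb") as f:
    et = ExecutionTrace(json.load(f))

# schema_pytorch() returns a version tuple, e.g. (1, 0, 3) for "1.0.3-chakra.0.0.4",
# which is what commsTraceParser.py now compares against (1, 0, 3).
print("pytorch schema:", et.schema_pytorch())

v = TraceValidator(et)
print(f"ops={v.num_ops()} comms={v.num_comm_ops()} triton={v.num_triton_ops()}")
print("valid:", v.validate())
```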