Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Add a trace validator helper and simple unit test for #102

Open
wants to merge 1 commit into
base: main
Choose a base branch
from
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
4 changes: 2 additions & 2 deletions train/comms/pt/commsTraceParser.py
Original file line number Diff line number Diff line change
Expand Up @@ -217,8 +217,8 @@ def _parseExecutionTrace(

"""
# Execution Trace PG_ID types availability
ET_PG_NAME_TUPLE = True if in_trace.schema == "1.0.3-chakra.0.0.4" else False
ET_BACKENDID = True if in_trace.schema != "1.0.3-chakra.0.0.4" else False
ET_PG_NAME_TUPLE = in_trace.schema_pytorch() >= (1, 0, 3)
ET_BACKENDID = in_trace.schema_pytorch() < (1, 0, 3)

initOps = []
newCommsTrace = []
Expand Down
Binary file not shown.
38 changes: 38 additions & 0 deletions train/compute/python/test/test_execution_trace.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,38 @@
import gzip
import json
import os
import unittest

from param_bench.train.compute.python.tools.execution_trace import ExecutionTrace
from param_bench.train.compute.python.tools.validate_trace import TraceValidator

CURR_DIR = os.path.dirname(os.path.realpath(__file__))


class TestTraceLoadAndValidate(unittest.TestCase):
    """Load execution-trace fixtures from disk and run them through TraceValidator."""

    def setUp(self):
        # Trace fixtures live under <this test dir>/data.
        self.trace_base = os.path.join(CURR_DIR, "data")

    def _test_and_validate_trace(self, trace_file):
        """Open a (possibly gzip-compressed) trace JSON file, assert it
        validates, and return (validator, parsed ExecutionTrace)."""
        if trace_file.endswith("gz"):
            opener = gzip.open(trace_file, "rb")
        else:
            opener = open(trace_file, "r")
        with opener as execution_data:
            execution_trace: ExecutionTrace = ExecutionTrace(json.load(execution_data))
            validator = TraceValidator(execution_trace)
            self.assertTrue(validator.validate())
            return validator, execution_trace

    def test_trace_load_resnet_1gpu(self):
        # Single-GPU ResNet trace captured with the 1.0.3-chakra.0.0.4 schema.
        et_file = os.path.join(
            self.trace_base, "1.0.3-chakra.0.0.4/resnet_1gpu_et.json.gz"
        )
        validator, et = self._test_and_validate_trace(et_file)
        self.assertGreater(validator.num_ops(), 1000)
        self.assertEqual(validator.num_comm_ops(), 12)
        self.assertEqual(validator.num_triton_ops(), 0)


if __name__ == "__main__":
    # Allow running this test module directly (outside a test runner).
    unittest.main()
11 changes: 11 additions & 0 deletions train/compute/python/tools/execution_trace.py
Original file line number Diff line number Diff line change
Expand Up @@ -350,6 +350,17 @@ def __init__(self, json):
# remove all dataloader ops
self.remove_dataloader_ops()

def _versiontuple(self, v: str) -> Tuple[int]:
return tuple(map(int, (v.split("."))))

def schema_pytorch(self) -> Tuple[int, ...]:
    """Return the PyTorch part of the schema string as a version tuple.

    For a schema like "1.0.3-chakra.0.0.4" this yields (1, 0, 3); the
    Chakra suffix (after "-") is ignored.
    """
    # Tuple[int, ...] corrects the original Tuple[int] (a 1-tuple type).
    return self._versiontuple(self.schema.split("-")[0])

def schema_chakra(self) -> Tuple[int, ...]:
    """Return the Chakra part of the schema string as a version tuple.

    A schema like "1.0.3-chakra.0.0.4" yields (0, 0, 4). Schemas with no
    "-" separator predate Chakra versioning and map to (0, 0, 0).
    """
    if "-" not in self.schema:
        return (0, 0, 0)
    # split("-", 1)[1] is e.g. "chakra.0.0.4": strip the "chakra." tag
    # before parsing, otherwise _versiontuple() raises ValueError on
    # int("chakra"). (The original passed the tagged string straight in.)
    chakra_ver = self.schema.split("-", 1)[1]
    if chakra_ver.startswith("chakra."):
        chakra_ver = chakra_ver[len("chakra."):]
    return self._versiontuple(chakra_ver)

@staticmethod
def _read_attrs(node: Dict[str, Any]) -> Tuple:
attr_types = {
Expand Down
113 changes: 113 additions & 0 deletions train/compute/python/tools/validate_trace.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,113 @@
from __future__ import (
absolute_import,
annotations,
division,
print_function,
unicode_literals,
)

import gzip
import json

from .execution_trace import ExecutionTrace


class TraceValidator:
    """Collection of sanity checks over a parsed ExecutionTrace.

    Each private ``_validate_*`` method prints diagnostics and returns a
    bool; ``validate()`` aggregates all of them.
    """

    def __init__(self, execution_trace: ExecutionTrace):
        self.et = execution_trace

    def _ops(self):
        # Lazily yield only the operator nodes of the trace.
        return (node for node in self.et.nodes.values() if node.is_op())

    def _validate_ops(self) -> bool:
        """Make sure the pytorch operators are valid"""
        for operator in self._ops():
            if operator.name == "":
                print(f"op should have valid name, node id = {operator.id}")

        # if len(list(op.get_outputs())) + len(list(op.get_inputs())) == 0:
        #     print(f"op should have outputs or inputs, node = {op.name}")
        # FIXME see "autograd::engine::evaluate_function: DivBackward1"
        # currently let's skip this
        # return False
        return True

    def _validate_tree(self) -> bool:
        """TBD validate that the generated datastructure is a tree
        with parent/child relationship. We can use pydot or networkx libs for this
        """
        return True

    def _validate_param_comms(self) -> bool:
        """Check if param comms has correct attributes"""
        # This should use the comms parser, for now something simple will be fine
        # https://github.com/facebookresearch/param/blob/main/train/comms/pt/commsTraceParser.py#L256

        # Older PyTorch schemas do not carry process-group info; nothing to check.
        if self.et.schema_pytorch() < (1, 0, 2):
            return True

        def check_comms_node(node) -> bool:
            """TODO use comms parser"""
            # Slightly hacky, but find an argument with tuple type.
            found_pg_id = False
            for arg in node.get_inputs():
                if arg[0] == "Tuple[String,String]":
                    print(f" {node.name}, process group args = {arg}")
                    found_pg_id = True
            return found_pg_id

        comm_nodes = (
            node
            for node in self.et.nodes.values()
            if node.is_op() and node.name == "record_param_comms"
        )
        return all(check_comms_node(node) for node in comm_nodes)

    def _validate_triton(self) -> bool:
        """Make sure triton kernels have correct values
        TODO update for checking if kernel files are captured.
        """
        return True

    def validate(self) -> bool:
        # Evaluate every check eagerly (list, not generator) so each one
        # runs and prints its diagnostics even if an earlier check fails.
        checks = [
            self._validate_ops(),
            self._validate_tree(),
            self._validate_param_comms(),
            self._validate_triton(),
        ]
        return all(checks)

    def num_ops(self) -> int:
        """Total number of operator nodes in the trace."""
        return sum(1 for _ in self._ops())

    def num_comm_ops(self) -> int:
        """Number of record_param_comms operator nodes."""
        return sum(1 for op in self._ops() if op.name == "record_param_comms")

    def num_triton_ops(self) -> int:
        """Number of operator nodes whose name mentions triton."""
        return sum(1 for op in self._ops() if "triton" in op.name)


def main():
    """CLI entry point: load the execution-trace JSON given as argv[1],
    print op/comms/triton counts, then print the validation result."""
    import sys

    trace_path = sys.argv[1]

    # Traces may be stored gzip-compressed; choose the matching opener.
    if trace_path.endswith("gz"):
        trace_file = gzip.open(trace_path, "rb")
    else:
        trace_file = open(trace_path, "r")

    with trace_file as execution_data:
        execution_trace: ExecutionTrace = ExecutionTrace(json.load(execution_data))
        validator = TraceValidator(execution_trace)
        print(
            f"num ops = {validator.num_ops()}, num comms = {validator.num_comm_ops()}, "
            f"num triton ops = {validator.num_triton_ops()}"
        )
        print("Trace validation result = ", validator.validate())


if __name__ == "__main__":
    main()
Loading