Skip to content

Commit

Permalink
Add a trace validator helper and a simple unit test for execution traces
Browse files Browse the repository at this point in the history
Summary:
## Summary
* Adds a trace validation tool that can check PyTorch host execution traces. This is helpful for schema changes and integration testing.
* Add a unit test to check if execution_trace.py works correctly on preset traces.
* Minor: add helpers to read the semantic versions of the PyTorch and Chakra schemas.

Reviewed By: shengbao-zheng

Differential Revision: D56325885
  • Loading branch information
briancoutinho authored and facebook-github-bot committed Apr 22, 2024
1 parent 7868e09 commit 3bf0b05
Show file tree
Hide file tree
Showing 5 changed files with 164 additions and 2 deletions.
4 changes: 2 additions & 2 deletions train/comms/pt/commsTraceParser.py
Original file line number Diff line number Diff line change
Expand Up @@ -217,8 +217,8 @@ def _parseExecutionTrace(
"""
# Execution Trace PG_ID types availability
ET_PG_NAME_TUPLE = True if in_trace.schema == "1.0.3-chakra.0.0.4" else False
ET_BACKENDID = True if in_trace.schema != "1.0.3-chakra.0.0.4" else False
ET_PG_NAME_TUPLE = in_trace.schema_pytorch() >= (1, 0, 3)
ET_BACKENDID = in_trace.schema_pytorch() < (1, 0, 3)

initOps = []
newCommsTrace = []
Expand Down
Binary file not shown.
38 changes: 38 additions & 0 deletions train/compute/python/test/test_execution_trace.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,38 @@
import gzip
import json
import os
import unittest

from param_bench.train.compute.python.tools.execution_trace import ExecutionTrace
from param_bench.train.compute.python.tools.validate_trace import TraceValidator

CURR_DIR = os.path.dirname(os.path.realpath(__file__))


class TestTraceLoadAndValidate(unittest.TestCase):
    """Load bundled execution traces and run them through TraceValidator."""

    def setUp(self):
        # Preset traces live under test/data, keyed by schema version.
        self.trace_base = os.path.join(CURR_DIR, "data")

    def _test_and_validate_trace(self, trace_file):
        """Open a (possibly gzip-compressed) trace, validate it, return both."""
        opener = (
            gzip.open(trace_file, "rb")
            if trace_file.endswith("gz")
            else open(trace_file, "r")
        )
        with opener as execution_data:
            trace: ExecutionTrace = ExecutionTrace(json.load(execution_data))
            validator = TraceValidator(trace)
            self.assertTrue(validator.validate())
            return validator, trace

    def test_trace_load_resnet_1gpu(self):
        """A 1-GPU resnet trace should validate, with 12 comm ops and no triton."""
        et_file = os.path.join(
            self.trace_base, "1.0.3-chakra.0.0.4/resnet_1gpu_et.json.gz"
        )
        validator, _trace = self._test_and_validate_trace(et_file)
        self.assertGreater(validator.num_ops(), 1000)
        self.assertEqual(validator.num_comm_ops(), 12)
        self.assertEqual(validator.num_triton_ops(), 0)


# Allow running this test module directly: `python test_execution_trace.py`.
if __name__ == "__main__":
    unittest.main()
11 changes: 11 additions & 0 deletions train/compute/python/tools/execution_trace.py
Original file line number Diff line number Diff line change
Expand Up @@ -350,6 +350,17 @@ def __init__(self, json):
# remove all dataloader ops
self.remove_dataloader_ops()

def _versiontuple(self, v: str) -> Tuple[int]:
return tuple(map(int, (v.split("."))))

def schema_pytorch(self) -> Tuple[int, ...]:
    """Return the PyTorch component of the trace schema as a version tuple.

    A schema such as "1.0.3-chakra.0.0.4" yields (1, 0, 3); the part before
    "-" is the PyTorch schema version. Tuples compare element-wise, so the
    result supports ordered comparisons like `>= (1, 0, 3)`.
    """
    return self._versiontuple(self.schema.split("-")[0])

def schema_chakra(self) -> Tuple[int, ...]:
    """Return the Chakra component of the trace schema as a version tuple.

    A schema such as "1.0.3-chakra.0.0.4" yields (0, 0, 4). Traces without
    a chakra suffix map to (0, 0, 0).

    Bug fix: the suffix after "-" is "chakra.0.0.4", so feeding it directly
    to _versiontuple raised ValueError on int("chakra"). Skip non-numeric
    components before converting.
    """
    if "-" not in self.schema:
        return (0, 0, 0)
    suffix = self.schema.split("-")[1]
    return tuple(int(p) for p in suffix.split(".") if p.isdigit())

@staticmethod
def _read_attrs(node: Dict[str, Any]) -> Tuple:
attr_types = {
Expand Down
113 changes: 113 additions & 0 deletions train/compute/python/tools/validate_trace.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,113 @@
from __future__ import (
absolute_import,
annotations,
division,
print_function,
unicode_literals,
)

import gzip
import json

from .execution_trace import ExecutionTrace


class TraceValidator:
    """Runs a collection of sanity checks over a loaded execution trace.

    Individual checks cover operator nodes, tree structure, param-comms
    attributes and triton kernels; `validate()` aggregates them all.
    """

    def __init__(self, execution_trace: ExecutionTrace):
        self.et = execution_trace

    def _ops(self):
        """Yield only the nodes that represent pytorch operators."""
        for node in self.et.nodes.values():
            if node.is_op():
                yield node

    def _validate_ops(self) -> bool:
        """Make sure the pytorch operators are valid"""
        for node in self._ops():
            if node.name == "":
                print(f"op should have valid name, node id = {node.id}")

        # FIXME ops such as "autograd::engine::evaluate_function: DivBackward1"
        # currently have neither inputs nor outputs, so this presence check
        # stays disabled for now:
        #   if len(list(op.get_outputs())) + len(list(op.get_inputs())) == 0:
        #       print(f"op should have outputs or inputs, node = {op.name}")
        #       return False
        return True

    def _validate_tree(self) -> bool:
        """TBD validate that the generated datastructure is a tree
        with parent/child relationship. We can use pydot or networkx libs for this
        """
        return True

    def _validate_param_comms(self) -> bool:
        """Check if param comms has correct attributes"""
        # This should use the comms parser, for now something simple will be fine
        # https://github.com/facebookresearch/param/blob/main/train/comms/pt/commsTraceParser.py#L256

        # Older schemas do not carry process group information; nothing to check.
        if self.et.schema_pytorch() < (1, 0, 2):
            return True

        def has_pg_id_arg(node) -> bool:
            """TODO use comms parser"""
            # Slightly hacky but find a argument with tuple type
            found = False
            for arg in node.get_inputs():
                if arg[0] == "Tuple[String,String]":
                    print(f" {node.name}, process group args = {arg}")
                    found = True
            return found

        comm_nodes = (
            node
            for node in self.et.nodes.values()
            if node.is_op() and node.name == "record_param_comms"
        )
        return all(has_pg_id_arg(node) for node in comm_nodes)

    def _validate_triton(self) -> bool:
        """Make sure triton kernels have correct values
        TODO update for checking if kernel files are captured.
        """
        return True

    def validate(self) -> bool:
        """Run every individual check; True only if all of them pass."""
        checks = [
            self._validate_ops(),
            self._validate_tree(),
            self._validate_param_comms(),
            self._validate_triton(),
        ]
        return all(checks)

    def num_ops(self) -> int:
        """Total number of pytorch operator nodes in the trace."""
        return sum(1 for _ in self._ops())

    def num_comm_ops(self) -> int:
        """Number of record_param_comms operator nodes."""
        return len([op for op in self._ops() if op.name == "record_param_comms"])

    def num_triton_ops(self) -> int:
        """Number of operator nodes whose name mentions triton."""
        return len([op for op in self._ops() if "triton" in op.name])


def main():
    """CLI entry point: load and validate the trace file given as argv[1]."""
    import sys

    execution_json = sys.argv[1]

    # Gzip-compressed traces are opened in binary mode; json.load accepts both.
    fh = (
        gzip.open(execution_json, "rb")
        if execution_json.endswith("gz")
        else open(execution_json, "r")
    )
    with fh as execution_data:
        trace = ExecutionTrace(json.load(execution_data))
        validator = TraceValidator(trace)
        print(
            f"num ops = {validator.num_ops()}, num comms = {validator.num_comm_ops()}, "
            f"num triton ops = {validator.num_triton_ops()}"
        )
        print("Trace validation result = ", validator.validate())


if __name__ == "__main__":
    main()

0 comments on commit 3bf0b05

Please sign in to comment.