Skip to content

Commit

Permalink
Add benchmark experiment test (pytorch#6081)
Browse files Browse the repository at this point in the history
  • Loading branch information
frgossen authored Dec 11, 2023
1 parent b224350 commit dbd3e33
Show file tree
Hide file tree
Showing 3 changed files with 35 additions and 2 deletions.
1 change: 1 addition & 0 deletions benchmarks/__init__.py
Original file line number Diff line number Diff line change
@@ -0,0 +1 @@

12 changes: 10 additions & 2 deletions test/benchmarks/run_tests.sh
Original file line number Diff line number Diff line change
Expand Up @@ -4,6 +4,9 @@ CDIR="$(cd "$(dirname "$0")" ; pwd -P)"
LOGFILE=/tmp/pytorch_benchmarks_test.log
VERBOSITY=0

# Make benchmark module available as it is not part of torch_xla.
export PYTHONPATH=$PYTHONPATH:$CDIR/../../benchmarks/

# Note [Keep Going]
#
# Set the `CONTINUE_ON_ERROR` flag to `true` to make the CircleCI tests continue on error.
Expand All @@ -19,10 +22,10 @@ do
case $OPTION in
L)
LOGFILE=
;;
;;
V)
VERBOSITY=$OPTARG
;;
;;
esac
done
shift $(($OPTIND - 1))
Expand All @@ -35,8 +38,13 @@ function run_make_tests {
make -C $CDIR $MAKE_V all
}

function run_python_tests {
python3 "$CDIR/test_benchmark_experiment.py"
}

function run_tests {
run_make_tests
run_python_tests
}

if [ "$LOGFILE" != "" ]; then
Expand Down
24 changes: 24 additions & 0 deletions test/benchmarks/test_benchmark_experiment.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,24 @@
import unittest

from benchmark_experiment import BenchmarkExperiment


class BenchmarkExperimentTest(unittest.TestCase):

def test_to_dict(self):
be = BenchmarkExperiment("some name", "cpu", "PJRT", "some xla_flags",
"openxla", "train", "123")
actual = be.to_dict()
self.assertEqual(8, len(actual))
self.assertEqual("some name", actual["experiment_name"])
self.assertEqual("cpu", actual["accelerator"])
self.assertTrue("accelerator_model" in actual)
self.assertEqual("PJRT", actual["xla"])
self.assertEqual("some xla_flags", actual["xla_flags"])
self.assertEqual("openxla", actual["dynamo"])
self.assertEqual("train", actual["test"])
self.assertEqual("123", actual["batch_size"])


if __name__ == '__main__':
unittest.main()

0 comments on commit dbd3e33

Please sign in to comment.