Skip to content

Commit

Permalink
Increase coverage run_baselines
Browse files Browse the repository at this point in the history
  • Loading branch information
maximilianreimer committed May 7, 2021
1 parent 9a65412 commit 3fb72d1
Show file tree
Hide file tree
Showing 2 changed files with 68 additions and 15 deletions.
7 changes: 4 additions & 3 deletions dacbench/run_baselines.py
Original file line number Diff line number Diff line change
@@ -1,5 +1,6 @@
import argparse
import itertools
import sys
from pathlib import Path

import numpy as np
Expand Down Expand Up @@ -126,7 +127,7 @@ def run_policy(results_path, benchmark_name, num_episodes, policy, seeds=np.aran
logger.close()


def main():
def main(args):
parser = argparse.ArgumentParser(
description="Run simple baselines for DAC benchmarks",
formatter_class=argparse.ArgumentDefaultsHelpFormatter,
Expand Down Expand Up @@ -203,7 +204,7 @@ def main():
default=0,
help="Fixes random actions for n steps",
)
args = parser.parse_args()
args = parser.parse_args(args)

if args.benchmarks is None:
benchs = benchmarks.__all__
Expand Down Expand Up @@ -246,4 +247,4 @@ def main():


if __name__ == "__main__":
main()
main(sys.argv[1:])
76 changes: 64 additions & 12 deletions tests/test_run_baselines.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,23 +3,21 @@
from pathlib import Path
from dacbench.logger import load_logs, log2dataframe

from dacbench.run_baselines import run_random, DISCRETE_ACTIONS, run_static

benchmarks = [
"SigmoidBenchmark",
"LubyBenchmark",
"FastDownwardBenchmark",
"CMAESBenchmark",
"ModeaBenchmark",
"SGDBenchmark",
]
from dacbench.run_baselines import (
run_random,
DISCRETE_ACTIONS,
run_static,
run_dynamic_policy,
run_optimal,
main,
)


class TestRunBaselines(unittest.TestCase):
def run_random_test_with_benchmark(self, benchmark):
seeds = [42]
fixed = 2
num_episodes = 10
num_episodes = 3

with tempfile.TemporaryDirectory() as temp_dir:
result_path = Path(temp_dir)
Expand Down Expand Up @@ -60,7 +58,7 @@ def test_run_random_ModeaBenchmark(self):

def run_static_test_with_benchmark(self, benchmark):
seeds = [42]
num_episodes = 10
num_episodes = 3
action = DISCRETE_ACTIONS[benchmark][0]
with tempfile.TemporaryDirectory() as temp_dir:
result_path = Path(temp_dir)
Expand Down Expand Up @@ -98,3 +96,57 @@ def test_run_static_SGDBenchmark(self):

def test_run_static_ModeaBenchmark(self):
self.run_static_test_with_benchmark("ModeaBenchmark")

def test_run_dynamic_policy_CMAESBenchmark(self):
    """Smoke-test the CSA dynamic policy on CMA-ES and validate its log output.

    Runs ``run_dynamic_policy`` for a handful of episodes into a temporary
    directory, then checks that the expected experiment folder and
    performance log exist and that the log has one row per episode, all
    tagged with the requested seed.
    """
    benchmark = "CMAESBenchmark"
    seed = 42
    episodes = 3
    with tempfile.TemporaryDirectory() as tmp:
        out_dir = Path(tmp)

        run_dynamic_policy(out_dir, benchmark, episodes, [seed])

        # Layout convention: <results>/<benchmark>/csa_<seed>/
        experiment_dir = out_dir / benchmark / f"csa_{seed}"
        self.assertTrue(experiment_dir.exists())

        log_file = experiment_dir / "PerformanceTrackingWrapper.jsonl"
        self.assertTrue(log_file.exists())

        frame = log2dataframe(load_logs(log_file))
        # One logged row per episode, every row carrying the run's seed.
        self.assertEqual(episodes, len(frame))
        self.assertTrue((frame["seed"] == seed).all())

def run_optimal_test_with_benchmark(self, benchmark):
    """Shared driver for the ``run_optimal`` smoke tests.

    Executes the optimal baseline for *benchmark* for a few episodes in a
    throwaway directory and asserts that the ``optimal_<seed>`` experiment
    folder and its performance log are written, with one log entry per
    episode under the expected seed.
    """
    seed = 42
    episodes = 3
    with tempfile.TemporaryDirectory() as tmp:
        out_dir = Path(tmp)

        run_optimal(out_dir, benchmark, episodes, [seed])

        # Layout convention: <results>/<benchmark>/optimal_<seed>/
        experiment_dir = out_dir / benchmark / f"optimal_{seed}"
        self.assertTrue(experiment_dir.exists())

        log_file = experiment_dir / "PerformanceTrackingWrapper.jsonl"
        self.assertTrue(log_file.exists())

        frame = log2dataframe(load_logs(log_file))
        # One logged row per episode, every row carrying the run's seed.
        self.assertEqual(episodes, len(frame))
        self.assertTrue((frame["seed"] == seed).all())

def test_run_optimal_LubyBenchmark(self):
    # Optimal-baseline smoke test for the Luby benchmark (shared driver).
    self.run_optimal_test_with_benchmark("LubyBenchmark")

def test_run_optimal_SigmoidBenchmark(self):
    # Optimal-baseline smoke test for the Sigmoid benchmark (shared driver).
    self.run_optimal_test_with_benchmark("SigmoidBenchmark")

def test_run_optimal_FastDownwardBenchmark(self):
    # Optimal-baseline smoke test for the FastDownward benchmark (shared driver).
    self.run_optimal_test_with_benchmark("FastDownwardBenchmark")

def test_main_help(self):
    """``--help`` must exit cleanly via argparse.

    Merely asserting ``SystemExit`` is too weak: argparse also raises
    ``SystemExit`` (with code 2) on a *parse error*, so the original
    assertion would pass even if the parser were broken. Capture the
    exception and require the success exit code 0.
    """
    with self.assertRaises(SystemExit) as ctx:
        main(["--help"])
    # argparse exits with status 0 on --help; anything else is a failure.
    self.assertEqual(ctx.exception.code, 0)

0 comments on commit 3fb72d1

Please sign in to comment.