Skip to content

Commit

Permalink
Make test_runner.py warn on non-empty output dir
Browse files Browse the repository at this point in the history
also wrap logic into functions and clean up global vars

ghstack-source-id: 815c582011611a71005cc22bbd14310900465377
Pull Request resolved: #343
  • Loading branch information
wconstab committed May 17, 2024
1 parent a2ace60 commit f2c3a11
Showing 1 changed file with 89 additions and 73 deletions.
162 changes: 89 additions & 73 deletions test_runner.py
Original file line number Diff line number Diff line change
Expand Up @@ -17,11 +17,6 @@
import tomli as tomllib


parser = argparse.ArgumentParser()
parser.add_argument("output_dir")
args = parser.parse_args()


@dataclass
class OverrideDefinitions:
"""
Expand All @@ -32,77 +27,77 @@ class OverrideDefinitions:
test_descr: str = "default"


CONFIG_DIR = "./train_configs"

"""
key is the config file name and value is a list of OverrideDefinitions
that is used to generate variations of integration tests based on the
same root config file.
"""
integration_tests_flavors = defaultdict(list)
integration_tests_flavors["debug_model.toml"] = [
OverrideDefinitions(
[
[
f"--job.dump_folder {args.output_dir}/default/",
],
],
"Default",
),
OverrideDefinitions(
[
def build_test_list(args):
"""
key is the config file name and value is a list of OverrideDefinitions
that is used to generate variations of integration tests based on the
same root config file.
"""
integration_tests_flavors = defaultdict(list)
integration_tests_flavors["debug_model.toml"] = [
OverrideDefinitions(
[
"--training.compile",
f"--job.dump_folder {args.output_dir}/1d_compile/",
[
f"--job.dump_folder {args.output_dir}/default/",
],
],
],
"1D compile",
),
OverrideDefinitions(
[
"Default",
),
OverrideDefinitions(
[
"--training.tensor_parallel_degree 2 --model.norm_type=rmsnorm",
f"--job.dump_folder {args.output_dir}/eager_2d/",
[
"--training.compile",
f"--job.dump_folder {args.output_dir}/1d_compile/",
],
],
],
"Eager mode 2DParallel",
),
OverrideDefinitions(
[
"1D compile",
),
OverrideDefinitions(
[
"--checkpoint.enable_checkpoint",
f"--job.dump_folder {args.output_dir}/full_checkpoint/",
[
"--training.tensor_parallel_degree 2 --model.norm_type=rmsnorm",
f"--job.dump_folder {args.output_dir}/eager_2d/",
],
],
"Eager mode 2DParallel",
),
OverrideDefinitions(
[
"--checkpoint.enable_checkpoint",
f"--job.dump_folder {args.output_dir}/full_checkpoint/",
"--training.steps 20",
[
"--checkpoint.enable_checkpoint",
f"--job.dump_folder {args.output_dir}/full_checkpoint/",
],
[
"--checkpoint.enable_checkpoint",
f"--job.dump_folder {args.output_dir}/full_checkpoint/",
"--training.steps 20",
],
],
],
"Checkpoint Integration Test - Save Load Full Checkpoint",
),
OverrideDefinitions(
[
"Checkpoint Integration Test - Save Load Full Checkpoint",
),
OverrideDefinitions(
[
"--checkpoint.enable_checkpoint",
f"--job.dump_folder {args.output_dir}/model_weights_only_fp32/",
"--checkpoint.model_weights_only",
[
"--checkpoint.enable_checkpoint",
f"--job.dump_folder {args.output_dir}/model_weights_only_fp32/",
"--checkpoint.model_weights_only",
],
],
],
"Checkpoint Integration Test - Save Model Weights Only fp32",
),
OverrideDefinitions(
[
"Checkpoint Integration Test - Save Model Weights Only fp32",
),
OverrideDefinitions(
[
"--checkpoint.enable_checkpoint",
f"--job.dump_folder {args.output_dir}/model_weights_only_bf16/",
"--checkpoint.model_weights_only",
"--checkpoint.export_dtype bfloat16",
[
"--checkpoint.enable_checkpoint",
f"--job.dump_folder {args.output_dir}/model_weights_only_bf16/",
"--checkpoint.model_weights_only",
"--checkpoint.export_dtype bfloat16",
],
],
],
"Checkpoint Integration Test - Save Model Weights Only bf16",
),
]
"Checkpoint Integration Test - Save Model Weights Only bf16",
),
]
return integration_tests_flavors


def run_test(test_flavor: OverrideDefinitions, full_path: str):
Expand All @@ -128,12 +123,33 @@ def run_test(test_flavor: OverrideDefinitions, full_path: str):
)


for config_file in os.listdir(CONFIG_DIR):
if config_file.endswith(".toml"):
full_path = os.path.join(CONFIG_DIR, config_file)
with open(full_path, "rb") as f:
config = tomllib.load(f)
is_integration_test = config["job"].get("use_for_integration_test", False)
if is_integration_test:
for test_flavor in integration_tests_flavors[config_file]:
run_test(test_flavor, full_path)
def run_tests(args):
integration_tests_flavors = build_test_list(args)
for config_file in os.listdir(args.config_dir):
if config_file.endswith(".toml"):
full_path = os.path.join(args.config_dir, config_file)
with open(full_path, "rb") as f:
config = tomllib.load(f)
is_integration_test = config["job"].get(
"use_for_integration_test", False
)
if is_integration_test:
for test_flavor in integration_tests_flavors[config_file]:
run_test(test_flavor, full_path)


def main():
parser = argparse.ArgumentParser()
parser.add_argument("output_dir")
parser.add_argument("--config_dir", default="./train_configs")
args = parser.parse_args()

if not os.path.exists(args.output_dir):
os.makedirs(args.output_dir)
if os.listdir(args.output_dir):
raise RuntimeError("Please provide an empty output directory.")
run_tests(args)


if __name__ == "__main__":
main()

0 comments on commit f2c3a11

Please sign in to comment.