From 5ccfaf6cf127c4bf1649995105a4b9378cbf07ce Mon Sep 17 00:00:00 2001
From: Tianyu Liu
Date: Sun, 15 Dec 2024 18:17:32 -0800
Subject: [PATCH] [BE] add integration test for the generation script

[ghstack-poisoned]
---
 scripts/generate/test_generate.py | 16 +++++++++----
 tests/integration_tests.py        | 37 +++++++++++++++++++++++++------
 2 files changed, 42 insertions(+), 11 deletions(-)

diff --git a/scripts/generate/test_generate.py b/scripts/generate/test_generate.py
index f46c0967..5610cbe8 100644
--- a/scripts/generate/test_generate.py
+++ b/scripts/generate/test_generate.py
@@ -95,12 +95,20 @@ def test_generate(
             "The input prompt is empty, model will respond from a empty sequence."
         )
 
-    utils.set_determinism(seed)
+    if seed is not None:
+        torch.manual_seed(seed)
+        # PYTHONHASHSEED can be a decimal number in the range [0, 2**32 - 1]
+        os.environ["PYTHONHASHSEED"] = str(seed % 2**32)
+        torch.use_deterministic_algorithms(True)
+        torch.backends.cudnn.deterministic = True
+        torch.backends.cudnn.benchmark = False
+        # env var for deterministic CuBLAS
+        # https://pytorch.org/docs/stable/generated/torch.use_deterministic_algorithms.html
+        os.environ["CUBLAS_WORKSPACE_CONFIG"] = ":4096:8"
 
-    if seed is None:
-        logger.info("Deterministic sampling off")
-    else:
         logger.info(f"Deterministic sampling on. Using seed: {seed}")
+    else:
+        logger.info("Deterministic sampling off")
 
     world_size = int(os.environ.get("WORLD_SIZE", 1))
     local_rank = int(os.environ.get("LOCAL_RANK", 0))
diff --git a/tests/integration_tests.py b/tests/integration_tests.py
index dcd13d6d..93cad886 100755
--- a/tests/integration_tests.py
+++ b/tests/integration_tests.py
@@ -368,6 +368,19 @@ def build_test_list():
             "fsdp+tp+cp",
             ngpu=8,
         ),
+        OverrideDefinitions(
+            [
+                [
+                    "--checkpoint.enable_checkpoint",
+                    "--experimental.pipeline_parallel_degree 2",
+                    "--training.enable_cpu_offload True",
+                    "--optimizer.early_step_in_backward",
+                ],
+            ],
+            "Enable CPU Offload with PP",
+            "enable_cpu_offload+PP",
+            ngpu=4,
+        ),
         OverrideDefinitions(
             [
                 [
@@ -382,14 +395,14 @@
             [
                 [
                     "--checkpoint.enable_checkpoint",
-                    "--experimental.pipeline_parallel_degree 2",
-                    "--training.enable_cpu_offload True",
-                    "--optimizer.early_step_in_backward",
+                ],
+                [
+                    # placeholder for the generation script's generate step
                 ],
             ],
-            "Enable CPU Offload with PP",
-            "enable_cpu_offload+PP",
-            ngpu=4,
+            "Generation script test",
+            "test_generate",
+            ngpu=2,
         ),
     ]
     return integration_tests_flavors
@@ -412,7 +425,7 @@ def run_test(test_flavor: OverrideDefinitions, full_path: str, output_dir: str):
     model_flavor_arg = f"--model.flavor {test_flavor.model_flavor}"
     all_ranks = ",".join(map(str, range(test_flavor.ngpu)))
 
-    for override_arg in test_flavor.override_args:
+    for idx, override_arg in enumerate(test_flavor.override_args):
         cmd = f"CONFIG_FILE={full_path} NGPU={test_flavor.ngpu} LOG_RANK={all_ranks} ./run_llama_train.sh"
         # dump compile trace for debugging purpose
         cmd = f'TORCH_TRACE="{output_dir}/{test_name}/compile_trace" ' + cmd
@@ -428,6 +441,16 @@
         logger.info(
             f"=====Integration test, flavor : {test_flavor.test_descr}, command : {cmd}====="
         )
+
+        # save checkpoint (idx == 0) and load it for generation (idx == 1)
+        if test_name == "test_generate" and idx == 1:
+            cmd = (
+                f"CONFIG_FILE={full_path} NGPU={test_flavor.ngpu} LOG_RANK={all_ranks} "
+                f"CHECKPOINT_DIR={output_dir}/{test_name}/checkpoint/step-10 "
+                "PROMPT='What is the meaning of life?' "
" + f"./scripts/generate/run_llama_generate.sh --out > {output_dir}/{test_name}/generated_output.json" + ) + result = _run_cmd(cmd) logger.info(result.stdout) if result.returncode != 0: