[BE] add integration test for the generation script #741

Closed · wants to merge 5 commits · Changes from 2 commits
16 changes: 12 additions & 4 deletions scripts/generate/test_generate.py
@@ -95,12 +95,20 @@ def test_generate(
"The input prompt is empty, model will respond from an empty sequence."
)

utils.set_determinism(seed)
if seed is not None:
torch.manual_seed(seed)
# PYTHONHASHSEED can be a decimal number in the range [0, 2**32 - 1]
os.environ["PYTHONHASHSEED"] = str(seed % 2**32)
torch.use_deterministic_algorithms(True)
torch.backends.cudnn.deterministic = True
torch.backends.cudnn.benchmark = False
# env var for deterministic CuBLAS
# https://pytorch.org/docs/stable/generated/torch.use_deterministic_algorithms.html
os.environ["CUBLAS_WORKSPACE_CONFIG"] = ":4096:8"
Contributor:
Why is the set_determinism() removed from the test?

Contributor Author:
This call to set_determinism() was added before @wconstab's PR that fixed the RNG handling, and that PR broke this line. Given how seed is used explicitly here, I'm not sure whether the code over there would still apply (would it still be correct?). Need @XilunWu's help to understand this better.

This PR tries to restore the behavior before the "BC-breaking change" and guard on its running.

Contributor:
Why is the set_determinism change for correct RNG affecting this code? I would hope that calling `set_determinism(seed, deterministic=True)` would be equivalent to the manual setup done here. What is the issue?

Contributor Author:
Discussed offline; refactored to use set_determinism.


if seed is None:
logger.info("Deterministic sampling off")
else:
logger.info(f"Deterministic sampling on. Using seed: {seed}")
else:
logger.info("Deterministic sampling off")

world_size = int(os.environ.get("WORLD_SIZE", 1))
local_rank = int(os.environ.get("LOCAL_RANK", 0))
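For reference, the manual seeding in the hunk above could be factored into a small helper along the following lines. This is only a sketch (function names are hypothetical, not from the PR), and the torch import is guarded so the fragment runs standalone:

```python
import os


def normalize_hash_seed(seed: int) -> int:
    """Clamp an arbitrary integer seed into the range PYTHONHASHSEED
    accepts, i.e. [0, 2**32 - 1], matching the modulo in the diff."""
    return seed % 2**32


def apply_determinism(seed: int) -> None:
    """Hypothetical consolidation of the manual setup in the diff above."""
    os.environ["PYTHONHASHSEED"] = str(normalize_hash_seed(seed))
    # env var for deterministic cuBLAS kernels, see the
    # torch.use_deterministic_algorithms documentation
    os.environ["CUBLAS_WORKSPACE_CONFIG"] = ":4096:8"
    try:
        import torch

        torch.manual_seed(seed)
        torch.use_deterministic_algorithms(True)
        torch.backends.cudnn.deterministic = True
        torch.backends.cudnn.benchmark = False
    except ImportError:
        # torch not installed; only the environment variables are set
        pass
```

This is essentially what the later refactor folds back into `set_determinism`, as discussed in the review thread.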
37 changes: 30 additions & 7 deletions tests/integration_tests.py
@@ -368,6 +368,19 @@ def build_test_list():
"fsdp+tp+cp",
ngpu=8,
),
OverrideDefinitions(
[
[
"--checkpoint.enable_checkpoint",
"--experimental.pipeline_parallel_degree 2",
"--training.enable_cpu_offload True",
"--optimizer.early_step_in_backward",
],
],
"Enable CPU Offload with PP",
"enable_cpu_offload+PP",
ngpu=4,
),
OverrideDefinitions(
[
[
@@ -382,14 +395,14 @@
[
[
"--checkpoint.enable_checkpoint",
"--experimental.pipeline_parallel_degree 2",
"--training.enable_cpu_offload True",
"--optimizer.early_step_in_backward",
],
[
# placeholder for the generation script's generate step
Contributor:
Is this part WIP?

Contributor Author:
No, I wrote it this way because for generation the first step is to call run_llama_train.sh and the second step is to call the generation script:
https://github.com/pytorch/torchtitan/pull/741/files#diff-3b751e36d12b5fa68ae66727b4d8c6ef2cce12d4f2444b46fd882942c9f4a87fR441-R447

To make it less hacky I think we need some surgery to the file & classes.

],
],
"Enable CPU Offload with PP",
"enable_cpu_offload+PP",
ngpu=4,
"Generation script test",
"test_generate",
ngpu=2,
),
]
return integration_tests_flavors
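The two-step flavor added above can be illustrated with a minimal stand-in for `OverrideDefinitions` (field names inferred from how `run_test` uses the object; the real class in the repository may differ):

```python
from dataclasses import dataclass, field
from typing import List


@dataclass
class OverrideDefinitions:
    """Minimal stand-in: each inner list in override_args is one
    invocation the test runner executes in order."""

    override_args: List[List[str]] = field(default_factory=list)
    test_descr: str = "default"
    test_name: str = "default"
    ngpu: int = 4
    model_flavor: str = "debugmodel"


# The flavor from the diff: step 0 trains and saves a checkpoint,
# step 1 is the empty slot the runner rewrites into the generation
# command (the "placeholder" discussed in the review thread).
generate_flavor = OverrideDefinitions(
    [
        ["--checkpoint.enable_checkpoint"],
        [],  # placeholder for the generation script's generate step
    ],
    "Generation script test",
    "test_generate",
    ngpu=2,
)
```

This makes the reviewer's "is this part WIP?" question concrete: the second, empty arg list is intentional, and the runner special-cases it by test name and step index.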
@@ -412,7 +425,7 @@ def run_test(test_flavor: OverrideDefinitions, full_path: str, output_dir: str):
model_flavor_arg = f"--model.flavor {test_flavor.model_flavor}"
all_ranks = ",".join(map(str, range(test_flavor.ngpu)))

for override_arg in test_flavor.override_args:
for idx, override_arg in enumerate(test_flavor.override_args):
cmd = f"CONFIG_FILE={full_path} NGPU={test_flavor.ngpu} LOG_RANK={all_ranks} ./run_llama_train.sh"
# dump compile trace for debugging purpose
cmd = f'TORCH_TRACE="{output_dir}/{test_name}/compile_trace" ' + cmd
@@ -428,6 +441,16 @@
logger.info(
f"=====Integration test, flavor : {test_flavor.test_descr}, command : {cmd}====="
)

# save checkpoint (idx == 0) and load it for generation (idx == 1)
if test_name == "test_generate" and idx == 1:
cmd = (
f"CONFIG_FILE={full_path} NGPU={test_flavor.ngpu} LOG_RANK={all_ranks} "
f"CHECKPOINT_DIR={output_dir}/{test_name}/checkpoint/step-10 "
"PROMPT='What is the meaning of life?' "
f"./scripts/generate/run_llama_generate.sh --out > {output_dir}/{test_name}/generated_output.json"
)

result = _run_cmd(cmd)
logger.info(result.stdout)
if result.returncode != 0:
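The command-selection logic the diff adds to `run_test` can be sketched as a pure function (a simplification for illustration; the real code mutates `cmd` inside the loop and also prepends a `TORCH_TRACE` prefix):

```python
def build_command(
    idx: int,
    test_name: str,
    full_path: str,
    ngpu: int,
    all_ranks: str,
    output_dir: str,
) -> str:
    """Sketch of the two-step selection from the diff: step 0 runs
    training (saving a checkpoint), step 1 runs the generation script
    against that checkpoint."""
    cmd = (
        f"CONFIG_FILE={full_path} NGPU={ngpu} LOG_RANK={all_ranks} "
        "./run_llama_train.sh"
    )
    # save checkpoint (idx == 0) and load it for generation (idx == 1)
    if test_name == "test_generate" and idx == 1:
        cmd = (
            f"CONFIG_FILE={full_path} NGPU={ngpu} LOG_RANK={all_ranks} "
            f"CHECKPOINT_DIR={output_dir}/{test_name}/checkpoint/step-10 "
            "PROMPT='What is the meaning of life?' "
            f"./scripts/generate/run_llama_generate.sh --out "
            f"> {output_dir}/{test_name}/generated_output.json"
        )
    return cmd
```

As the author notes in the thread, a cleaner version would need some surgery on the file and its classes so each step can declare its own executable instead of being special-cased by test name and index.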