Commit 582fe7d

Merge branch 'main' into optional_checkpoint
mori360 authored Feb 7, 2025
2 parents b1f1d5d + 5940dde commit 582fe7d
Showing 10 changed files with 69 additions and 378 deletions.
11 changes: 0 additions & 11 deletions scripts/estimate/estimation.py
@@ -33,17 +33,6 @@ def estimate_memory(job_config: JobConfig):
     # Get the world size
     world_size = int(os.environ["WORLD_SIZE"])

-    # fake tensor doesn't work with fused rmsnorm
-    if (
-        job_config.model.norm_type == "fused_rmsnorm"
-        and not job_config.memory_estimation.disable_fake_mode
-    ):
-        logger.info(
-            "Fused RMSNorm is not supported yet under fake estimation mode. "
-            "Switching to rmsnorm."
-        )
-        job_config.model.norm_type = "rmsnorm"
-
     if job_config.model.norm_type == "compiled_rmsnorm":
         logger.info("Compiled RMSNorm is not supported yet. Switching to RMSNorm.")
         job_config.model.norm_type = "rmsnorm"
20 changes: 10 additions & 10 deletions tests/integration_tests.py
@@ -94,16 +94,6 @@ def build_test_list():
"2D compile",
"2d_compile",
),
OverrideDefinitions(
[
[
"--training.tensor_parallel_degree 2",
"--model.norm_type=fused_rmsnorm",
],
],
"2D eager with fused_rmsnorm",
"2d_eager_fused_rmsnorm",
),
OverrideDefinitions(
[
[
@@ -418,6 +408,16 @@ def build_test_list():
"test_generate",
ngpu=2,
),
OverrideDefinitions(
[
[
"--training.fsdp_reshard_after_forward always",
],
],
"Test always resharding after forward pass",
"fsdp_reshard_always",
ngpu=2,
),
OverrideDefinitions(
[
[
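For reference, each OverrideDefinitions entry expands into one training run per inner list of flags, executed on the requested number of GPUs. A simplified sketch of that expansion (the launch-script name and command assembly are assumptions for illustration, not the harness's actual code):

import subprocess

# The new test case's single override set, as it appears in the diff above.
override_sets = [["--training.fsdp_reshard_after_forward always"]]

for overrides in override_sets:
    # Assumed launcher; torchtitan's harness composes a similar command
    # and runs it with NGPU=2 for this test.
    cmd = "./run_llama_train.sh " + " ".join(overrides)
    print("would run:", cmd)
    # subprocess.run(cmd, shell=True, check=True)  # on a multi-GPU host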
72 changes: 0 additions & 72 deletions tests/unit_tests/test_fused_rms_norm_dtensor.py

This file was deleted.

20 changes: 19 additions & 1 deletion torchtitan/config_manager.py
@@ -174,7 +174,8 @@ def __init__(self):
"--model.norm_type",
type=str,
default="rmsnorm",
help="Type of layer normalization to use [layernorm, np_layernorm, rmsnorm, fused_rmsnorm]",
choices=["layernorm", "np_layernorm", "rmsnorm"],
help="Type of layer normalization to use [layernorm, np_layernorm, rmsnorm]",
)
self.parser.add_argument(
"--model.tokenizer_path",
@@ -284,6 +285,23 @@ def __init__(self):
action="store_true",
help="Whether to apply loss parallel when sequence parallel is enabled",
)
self.parser.add_argument(
"--training.fsdp_reshard_after_forward",
type=str,
default="default",
choices=["default", "always", "never"],
help="""
`reshard_after_forward` specifies the policy for applying `reshard_after_forward`
within an FSDP setup. `reshard_after_forward` controls parameter behavior after forward,
trading off memory and communication. See torch's `fully_shard` API for more documentation
on `reshard_after_forward`.
The supported policies include "default", "always" and "never":
- "default" applies default resharding behavior, implementing "smart defaults" for known optimal
scenarios.
- "always" will enable `reshard_after_forward` for all forward passes.
- "never" will disable `reshard_after_forward` for all forward passes.
""",
)
self.parser.add_argument(
"--experimental.enable_async_tensor_parallel",
action="store_true",
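To make the three policy values concrete, here is a sketch of how they might map onto FSDP2's `fully_shard(module, reshard_after_forward=...)` keyword. The mapping below is an assumption for illustration (torchtitan's real wiring lives in its parallelization code); the "default" branch mimics the kind of smart default the help text describes, where the last block skips resharding because its parameters are needed again immediately when backward starts:

# Hypothetical policy mapping; not torchtitan's actual code.
from torch.distributed._composable.fsdp import fully_shard  # FSDP2, torch >= 2.4


def shard_blocks(blocks, policy: str) -> None:
    last = len(blocks) - 1
    for i, block in enumerate(blocks):
        if policy == "always":
            reshard = True
        elif policy == "never":
            reshard = False
        else:  # "default": reshard every block except the last one
            reshard = i < last
        fully_shard(block, reshard_after_forward=reshard)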
(Diffs for the remaining 6 changed files were not loaded.)