Commit f758d08

isururanawaka authored and meta-codesync[bot] committed
Validate sharding plan correctness: Shard to rank assignment (#3495)
Summary:
Pull Request resolved: #3495

- validate correct rank assignment
- checks for None ranks
- checks for ranks in correct rank range
- check consistency in Manifold planner sharding plan rank assignment

Reviewed By: aporialiao

Differential Revision: D85878244

fbshipit-source-id: 09847323717609ece5d6480c5cf595f8a5fc4260
1 parent f7c74e4 commit f758d08
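
As a usage note (not part of the commit): a minimal sketch of how a caller could surface the new failure mode when requesting a plan. EmbeddingShardingPlanner, Topology, PlannerError, and PlannerErrorType.INVALID_RANK_ASSIGNMENT are names from the planner API and this diff; the model and sharders objects are assumed to exist (for example, built as in the unit test below), and the sketch assumes PlannerError keeps the error_type it was constructed with.

# Sketch only: handle the new INVALID_RANK_ASSIGNMENT failure from plan().
# model and sharders are assumed to be defined elsewhere.
import logging

from torchrec.distributed.planner import EmbeddingShardingPlanner, Topology
from torchrec.distributed.planner.types import PlannerError, PlannerErrorType

planner = EmbeddingShardingPlanner(
    topology=Topology(world_size=2, compute_device="cuda")
)
try:
    plan = planner.plan(module=model, sharders=sharders)
except PlannerError as e:
    if e.error_type == PlannerErrorType.INVALID_RANK_ASSIGNMENT:
        # A shard placement had a rank that was None, negative, or >= world size.
        logging.error("Planner produced an invalid rank assignment: %s", e)
    raise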

File tree

3 files changed (+67 -1 lines)


torchrec/distributed/planner/planners.py

Lines changed: 44 additions & 1 deletion
@@ -118,6 +118,45 @@ def to_sharding_plan(
     return ShardingPlan(plan)


+def validate_rank_assignment(sharding_plan: ShardingPlan, topology: Topology) -> None:
+    """
+    Validates that all shards in the given sharding plan have valid rank assignments.
+
+    This function iterates through each module and parameter in the provided sharding plan,
+    checking that each shard's placement has a valid rank (i.e., not None, not negative, and
+    less than the topology's world size). If any shard fails these checks, a PlannerError is raised.
+
+    Args:
+        sharding_plan (ShardingPlan): The sharding plan to validate.
+        topology (Topology): The topology containing world size information.
+
+    Raises:
+        PlannerError: If any shard has an invalid rank assignment or if a sharding spec is missing.
+    """
+    for module_name, module_plan in sharding_plan.plan.items():
+        # pyre-ignore
+        for param_name, param_plan in module_plan.items():
+            if param_plan.sharding_spec is not None:
+                for shard in param_plan.sharding_spec.shards:
+                    if shard.placement.rank() is None or shard.placement.rank() < 0:
+                        msg = f"Rank is not assigned for shard {shard}"
+                        logging.error(msg)
+                        raise PlannerError(
+                            error_type=PlannerErrorType.INVALID_RANK_ASSIGNMENT,
+                            message=msg,
+                        )
+                    if shard.placement.rank() >= topology.world_size:
+                        msg = f"Shard {shard} has rank {shard.placement.rank()}, which is not less than world size {topology.world_size}."
+                        logging.error(msg)
+                        raise PlannerError(
+                            error_type=PlannerErrorType.INVALID_RANK_ASSIGNMENT,
+                            message=msg,
+                        )
+            else:
+                msg = f"Sharding spec not found for {module_name}.{param_name}"
+                logging.warning(msg)
+
+
 def extract_plan(
     search_space: List[ShardingOption],
     loaded_sharding_options: Dict[int, ShardingOption],
@@ -634,6 +673,8 @@ def plan(
                     timeout_seconds=self._timeout_seconds,
                 ),
            )
+
+            validate_rank_assignment(sharding_plan, self._topology)
            return sharding_plan
        else:
            global_storage_capacity = reduce(
@@ -997,6 +1038,7 @@ def plan(
                sharding_plan = to_sharding_plan(
                    best_plan, self._topology_groups[group]
                )
+                validate_rank_assignment(sharding_plan, self._topology_groups[group])
                best_plans.append(sharding_plan)

        end_time = perf_counter()
@@ -1090,4 +1132,5 @@ def plan(
            + last_planner_error_info,
        )

-        return _merge_plans(best_plans)
+        sharding_plan = _merge_plans(best_plans)
+        return sharding_plan
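
Because validate_rank_assignment is a module-level function, it can also serve as a standalone sanity check on a plan produced or loaded elsewhere. A minimal sketch, assuming loaded_sharding_plan is an existing ShardingPlan (placeholder name) and that world_size=2 matches the topology it was planned for:

# Sketch only: validate an externally obtained ShardingPlan.
# loaded_sharding_plan is a placeholder for a plan loaded or merged elsewhere.
from torchrec.distributed.planner.planners import validate_rank_assignment
from torchrec.distributed.planner.types import Topology

topology = Topology(world_size=2, compute_device="cuda")
# Raises PlannerError(INVALID_RANK_ASSIGNMENT) if any shard rank is None,
# negative, or >= topology.world_size; logs a warning for missing sharding specs.
validate_rank_assignment(loaded_sharding_plan, topology)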

torchrec/distributed/planner/tests/test_planners.py

Lines changed: 22 additions & 0 deletions
@@ -78,6 +78,28 @@ def setUp(self) -> None:
         )
         self.planner = EmbeddingShardingPlanner(topology=self.topology)

+    def test_tw_rank_assignment(self) -> None:
+        tables = [
+            EmbeddingBagConfig(
+                num_embeddings=100,
+                embedding_dim=64,
+                name="table_" + str(i),
+                feature_names=["feature_" + str(i)],
+            )
+            for i in range(4)
+        ]
+        model = TestSparseNN(tables=tables, sparse_device=torch.device("meta"))
+        sharding_plan = self.planner.plan(module=model, sharders=[TWSharder()])
+        ranks = [
+            cast(List[int], param_shard.ranks)
+            for param_shard in cast(
+                EmbeddingModuleShardingPlan, sharding_plan.plan["sparse.ebc"]
+            ).values()
+        ]
+        for rank_list in ranks:
+            for rank in rank_list:
+                self.assertTrue(0 <= rank <= 1, f"Rank {rank} not in [0,1]")
+
     def test_tw_solution(self) -> None:
         tables = [
             EmbeddingBagConfig(
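
A negative-path companion test one could add (not part of this commit): validate the same table-wise plan against a deliberately smaller Topology so that at least one assigned rank falls outside the world size, and assert that PlannerError is raised. This sketch assumes the setUp topology uses world_size=2 and that validate_rank_assignment, PlannerError, and Topology are imported from the modules touched by this diff:

# Sketch only. At module level one would add:
#   from torchrec.distributed.planner.planners import validate_rank_assignment
#   from torchrec.distributed.planner.types import PlannerError, Topology

    def test_tw_rank_assignment_rejects_small_world_size(self) -> None:
        tables = [
            EmbeddingBagConfig(
                num_embeddings=100,
                embedding_dim=64,
                name="table_" + str(i),
                feature_names=["feature_" + str(i)],
            )
            for i in range(4)
        ]
        model = TestSparseNN(tables=tables, sparse_device=torch.device("meta"))
        sharding_plan = self.planner.plan(module=model, sharders=[TWSharder()])
        # With four table-wise shards balanced over two ranks, at least one table
        # is expected on rank 1, which is out of range for a world size of 1.
        too_small = Topology(world_size=1, compute_device="cuda")
        with self.assertRaises(PlannerError):
            validate_rank_assignment(sharding_plan, too_small)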

torchrec/distributed/planner/types.py

Lines changed: 1 addition & 0 deletions
@@ -799,6 +799,7 @@ class PlannerErrorType(Enum):
     OTHER = "other"
     PLANNER_INPUT_CONTEXT_MISMATCH = "planner_input_context_mismatch"
     PLAN_LOADING_FAILED = "plan_loading_failed"
+    INVALID_RANK_ASSIGNMENT = "invalid_rank_assignment"


 class PlannerError(Exception):
