@@ -118,6 +118,45 @@ def to_sharding_plan(
118118 return ShardingPlan (plan )
119119
120120
def validate_rank_assignment(sharding_plan: ShardingPlan, topology: Topology) -> None:
    """
    Validates that all shards in the given sharding plan have valid rank assignments.

    Walks every module and parameter of ``sharding_plan.plan`` and, for each shard of
    a parameter's sharding spec, checks that the placement rank is set, non-negative
    and strictly less than ``topology.world_size``. Parameters without a sharding spec
    are only reported with a warning; they do not fail validation.

    Args:
        sharding_plan (ShardingPlan): The sharding plan to validate.
        topology (Topology): The topology whose ``world_size`` bounds the valid ranks.

    Raises:
        PlannerError: With ``PlannerErrorType.INVALID_RANK_ASSIGNMENT`` if any shard
            has no placement/rank, a negative rank, or a rank that is not less than
            ``topology.world_size``.
    """
    # The bound comes from the topology being planned for, not from the live
    # process group: the same value is used for the check and for the message, and
    # validation keeps working when torch.distributed is not initialized.
    world_size = topology.world_size

    for module_name, module_plan in sharding_plan.plan.items():
        # pyre-ignore
        for param_name, param_plan in module_plan.items():
            sharding_spec = param_plan.sharding_spec
            if sharding_spec is None:
                msg = f"Sharding spec not found for {module_name}.{param_name}"
                logging.warning(msg)
                continue

            for shard in sharding_spec.shards:
                placement = shard.placement
                # A shard with no placement at all is treated like an unassigned
                # rank instead of failing with AttributeError on `.rank()`.
                rank = placement.rank() if placement is not None else None

                if rank is None or rank < 0:
                    msg = f"Rank is not assigned for shard {shard}"
                    logging.error(msg)
                    raise PlannerError(
                        error_type=PlannerErrorType.INVALID_RANK_ASSIGNMENT,
                        message=msg,
                    )

                if rank >= world_size:
                    msg = (
                        f"Shard {shard} has rank {rank} which is greater than "
                        f"or equal to world size {world_size}."
                    )
                    logging.error(msg)
                    raise PlannerError(
                        error_type=PlannerErrorType.INVALID_RANK_ASSIGNMENT,
                        message=msg,
                    )
158+
159+
121160def extract_plan (
122161 search_space : List [ShardingOption ],
123162 loaded_sharding_options : Dict [int , ShardingOption ],
@@ -634,6 +673,8 @@ def plan(
634673 timeout_seconds = self ._timeout_seconds ,
635674 ),
636675 )
676+
677+ validate_rank_assignment (sharding_plan , self ._topology )
637678 return sharding_plan
638679 else :
639680 global_storage_capacity = reduce (
@@ -997,6 +1038,7 @@ def plan(
9971038 sharding_plan = to_sharding_plan (
9981039 best_plan , self ._topology_groups [group ]
9991040 )
1041+ validate_rank_assignment (sharding_plan , self ._topology_groups [group ])
10001042 best_plans .append (sharding_plan )
10011043
10021044 end_time = perf_counter ()
@@ -1090,4 +1132,5 @@ def plan(
10901132 + last_planner_error_info ,
10911133 )
10921134
1093- return _merge_plans (best_plans )
1135+ sharding_plan = _merge_plans (best_plans )
1136+ return sharding_plan
0 commit comments