     set_seed,
     speed_metrics,
 )
+import torch_xla.distributed.parallel_loader as pl
 from .training_args import OptimizerNames, ParallelMode, TrainingArguments
 from .utils import (
     ADAPTER_CONFIG_NAME,
@@ -264,6 +265,7 @@ def _get_fsdp_ckpt_kwargs():
 
 logger = logging.get_logger(__name__)
 
+NUM_SLICE = int(os.getenv('NUM_SLICE', 1))
 
 # Name of the files used for checkpointing
 TRAINING_ARGS_NAME = "training_args.bin"
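
NUM_SLICE is read from the environment once at module import and defaults to 1; any larger value switches the trainer into the multi-slice (DCN) code paths further down. A minimal sketch, assuming the variable is set from Python before the trainer module is first imported (in practice it would usually come from the launch environment):

```python
import os

# Hypothetical launch-side setup for 2 TPU slices. The value is captured
# once at module import, so it must be set before `transformers.trainer`
# is first imported.
os.environ["NUM_SLICE"] = "2"

from transformers import Trainer, TrainingArguments  # noqa: E402
```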
@@ -381,6 +383,7 @@ def __init__(
             args = TrainingArguments(output_dir=output_dir)
         self.args = args
         # Seed must be set before instantiating the model when using model
+        set_seed(self.args.seed)
         enable_full_determinism(self.args.seed) if self.args.full_determinism else set_seed(self.args.seed)
         self.hp_name = None
         self.deepspeed = None
@@ -679,6 +682,18 @@ def __init__(
             # Tensor axis is just a placeholder where it will not be used in FSDPv2.
             num_devices = xr.global_runtime_device_count()
             xs.set_global_mesh(xs.Mesh(np.array(range(num_devices)), (num_devices, 1), axis_names=("fsdp", "tensor")))
+            if NUM_SLICE == 1:
+                mesh_shape = (num_devices, 1)
+                device_ids = np.array(range(num_devices))
+                # To be noted, the mesh must have an axis named 'fsdp', which the weights and activations will be sharded on.
+                mesh = xs.Mesh(device_ids, mesh_shape, ('fsdp', 'tensor'))
+                xs.set_global_mesh(mesh)
+            else:
+                dcn_axis = NUM_SLICE
+                ici_mesh_shape = (1, num_devices // dcn_axis, 1)
+                dcn_mesh_shape = (dcn_axis, 1, 1)
+                mesh = xs.HybridMesh(ici_mesh_shape=ici_mesh_shape, dcn_mesh_shape=dcn_mesh_shape, axis_names=('dcn', 'fsdp', 'tensor'))
+                xs.set_global_mesh(mesh)
 
     def _activate_neftune(self, model):
         r"""
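
The hunk above builds a single-axis SPMD mesh for one slice and an `xs.HybridMesh` when training spans several slices, with the outer `dcn` axis covering the inter-slice network and the inner `fsdp` axis covering the devices inside each slice. A standalone sketch of the two layouts, assuming torch_xla's SPMD API under `torch_xla.distributed.spmd` and an illustrative slice count of 2:

```python
import numpy as np
import torch_xla.runtime as xr
import torch_xla.distributed.spmd as xs

xr.use_spmd()  # enable SPMD mode before any mesh is created
num_devices = xr.global_runtime_device_count()
num_slices = 2  # illustrative; the patch reads this from NUM_SLICE

if num_slices == 1:
    # One slice: shard weights/activations along 'fsdp'; 'tensor' is a placeholder axis.
    mesh = xs.Mesh(np.array(range(num_devices)), (num_devices, 1), ("fsdp", "tensor"))
else:
    # Several slices: 'dcn' spans the slices (data-center network),
    # 'fsdp' spans the chips inside each slice (ICI).
    mesh = xs.HybridMesh(
        ici_mesh_shape=(1, num_devices // num_slices, 1),
        dcn_mesh_shape=(num_slices, 1, 1),
        axis_names=("dcn", "fsdp", "tensor"),
    )
xs.set_global_mesh(mesh)
```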
@@ -877,6 +892,24 @@ def get_train_dataloader(self) -> DataLoader:
             dataloader_params["worker_init_fn"] = seed_worker
             dataloader_params["prefetch_factor"] = self.args.dataloader_prefetch_factor
 
+
+        if is_torch_xla_available():
+            torch_dataloader = DataLoader(train_dataset, **dataloader_params)
+            device = xm.xla_device()
+            if NUM_SLICE == 1:
+                mp_device_loader = pl.MpDeviceLoader(
+                    torch_dataloader,
+                    device,
+                    input_sharding=xs.ShardingSpec(xs.get_global_mesh(), ("fsdp", None)),
+                )
+            else:
+                mp_device_loader = pl.MpDeviceLoader(
+                    torch_dataloader,
+                    device,
+                    input_sharding=xs.ShardingSpec(xs.get_global_mesh(), (("dcn", "fsdp"), None)),
+                )
+            return mp_device_loader
+
         return self.accelerator.prepare(DataLoader(train_dataset, **dataloader_params))
 
     def _get_eval_sampler(self, eval_dataset: Dataset) -> Optional[torch.utils.data.Sampler]:
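
`get_train_dataloader` now hands the host `DataLoader` to an `MpDeviceLoader` whose `input_sharding` tells XLA how to split each incoming batch: dimension 0 is sharded over `fsdp` (or over the combined `('dcn', 'fsdp')` axes in the multi-slice case) and the remaining dimensions stay replicated. A self-contained sketch with a hypothetical toy dataset, assuming the global mesh was set up as in the hunk above:

```python
import torch
import torch_xla.core.xla_model as xm
import torch_xla.distributed.spmd as xs
import torch_xla.distributed.parallel_loader as pl
from torch.utils.data import DataLoader, TensorDataset

# Hypothetical toy data: 64 samples of shape (128,), rank-2 labels so the
# same 2-entry partition spec applies to every tensor in the batch.
dataset = TensorDataset(torch.randn(64, 128), torch.randint(0, 2, (64, 1)))
host_loader = DataLoader(dataset, batch_size=8)

device_loader = pl.MpDeviceLoader(
    host_loader,
    xm.xla_device(),
    # Shard dim 0 (the batch) across 'fsdp'; use (("dcn", "fsdp"), None)
    # instead when the mesh has a 'dcn' axis.
    input_sharding=xs.ShardingSpec(xs.get_global_mesh(), ("fsdp", None)),
)

for features, labels in device_loader:
    pass  # tensors arrive on the XLA device, already sharded along dim 0
```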
@@ -1681,7 +1714,6 @@ def _wrap_model(self, model, training=True, dataloader=None):
                     # Transformer layer class to wrap
                     transformer_layer_cls=transformer_cls_to_wrap,
                 )
-            fsdp_kwargs = self.args.xla_fsdp_config
             if self.args.fsdp_config["xla_fsdp_grad_ckpt"]:
                 if model.config.use_cache:
                     logger.warning_once(
@@ -1709,7 +1741,11 @@ def shard_output(output, mesh):
 
                     if real_output is None:
                         raise ValueError("Something went wrong, the output of the model shouldn't be `None`")
-                    xs.mark_sharding(real_output, mesh, ("fsdp", None, None))
+
+                    if NUM_SLICE == 1:
+                        xs.mark_sharding(real_output, mesh, ("fsdp", None, None))
+                    else:
+                        xs.mark_sharding(real_output, mesh, (("dcn", "fsdp"), None, None))
 
                 self.model = model = FSDPv2(
                     model,
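
`shard_output` is the hook FSDPv2 uses to shard the model's logits; the partition spec must have one entry per output dimension (batch, sequence, vocabulary here), and only the batch dimension is sharded. A small sketch of `xs.mark_sharding` on a rank-3 tensor, assuming the same SPMD imports as above and illustrative shapes:

```python
import numpy as np
import torch
import torch_xla.core.xla_model as xm
import torch_xla.runtime as xr
import torch_xla.distributed.spmd as xs

xr.use_spmd()
n = xr.global_runtime_device_count()
mesh = xs.Mesh(np.array(range(n)), (n, 1), ("fsdp", "tensor"))

# Stand-in for the model output: (batch, seq_len, vocab_size).
logits = torch.zeros(8, 16, 32000, device=xm.xla_device())

# Shard only the batch dimension; with multiple slices the patch uses
# (("dcn", "fsdp"), None, None) so the batch is split across both axes.
xs.mark_sharding(logits, mesh, ("fsdp", None, None))
```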
@@ -1718,10 +1754,12 @@ def shard_output(output, mesh):
                     auto_wrapper_callable=auto_wrapper_callable,
                 )
             else:
+                fsdp_kwargs = self.args.xla_fsdp_config
                 self.model = model = FSDP(
                     model,
                     auto_wrap_policy=auto_wrap_policy,
                     auto_wrapper_callable=auto_wrapper_callable,
+                    reshard_after_forward=False,
                     **fsdp_kwargs,
                 )
 
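
In the non-SPMD branch the patch also passes `reshard_after_forward=False` to torch_xla's FSDP wrapper, keeping the gathered parameters resident between forward and backward instead of re-all-gathering them, which trades extra memory for fewer collectives. A minimal sketch with a toy module, assuming `XlaFullyShardedDataParallel` from `torch_xla.distributed.fsdp`:

```python
import torch
import torch.nn as nn
import torch_xla.core.xla_model as xm
from torch_xla.distributed.fsdp import XlaFullyShardedDataParallel as FSDP

# Toy module standing in for a wrapped transformer block.
model = nn.Linear(1024, 1024).to(xm.xla_device())

# Keep full parameters between forward and backward (more memory,
# one fewer all-gather per step).
model = FSDP(model, reshard_after_forward=False)

loss = model(torch.randn(4, 1024, device=xm.xla_device())).sum()
loss.backward()
xm.mark_step()  # materialize the lazily-traced graph
```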
@@ -1854,6 +1892,7 @@ def train(
                 # Disable progress bars when uploading models during checkpoints to avoid polluting stdout
                 hf_hub_utils.disable_progress_bars()
                 return inner_training_loop(
+                    batch_size=self._train_batch_size,
                     args=args,
                     resume_from_checkpoint=resume_from_checkpoint,
                     trial=trial,
@@ -1863,6 +1902,7 @@ def train(
                 hf_hub_utils.enable_progress_bars()
         else:
             return inner_training_loop(
+                batch_size=self._train_batch_size,
                 args=args,
                 resume_from_checkpoint=resume_from_checkpoint,
                 trial=trial,
@@ -1892,8 +1932,8 @@ def _inner_training_loop(
         logger.debug(f"Currently training with a batch size of: {self._train_batch_size}")
         # Data loader and number of training steps
         train_dataloader = self.get_train_dataloader()
-        if self.is_fsdp_xla_v2_enabled:
-            train_dataloader = tpu_spmd_dataloader(train_dataloader)
+        # if self.is_fsdp_xla_v2_enabled:
+        #     train_dataloader = tpu_spmd_dataloader(train_dataloader)
 
         # Setting up training control variables:
         # number of training epochs: num_train_epochs
@@ -4454,3 +4494,4 @@ def _fsdp_qlora_plugin_updates(self):
                 fsdp_plugin.set_mixed_precision(
                     self.model.hf_quantizer.quantization_config.bnb_4bit_quant_storage, override=True
                 )
+