fix world_size (#2801)
Jintao-Huang authored Dec 30, 2024
1 parent e7bbf0a · commit 132adc7
Showing 4 changed files with 8 additions and 12 deletions.
8 changes: 2 additions & 6 deletions in swift/llm/argument/base_args/base_args.py

```diff
@@ -130,14 +130,10 @@ def __post_init__(self):
         self._init_ckpt_dir()
         self._init_custom_register()
         self._init_model_kwargs()
-        self.rank, self.local_rank, world_size, self.local_world_size = get_dist_setting()
         # The Seq2SeqTrainingArguments has a property called world_size, which cannot be assigned a value.
-        try:
-            self.world_size = world_size
-        except AttributeError:
-            pass
+        self.rank, self.local_rank, self.global_world_size, self.local_world_size = get_dist_setting()
         logger.info(f'rank: {self.rank}, local_rank: {self.local_rank}, '
-                    f'world_size: {world_size}, local_world_size: {self.local_world_size}')
+                    f'world_size: {self.global_world_size}, local_world_size: {self.local_world_size}')
         assert len(self.adapters) <= 1, f'args.adapters: {self.adapters}'
         ModelArguments.__post_init__(self)
         QuantizeArguments.__post_init__(self)
```
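Note: the removed try/except existed because transformers' Seq2SeqTrainingArguments exposes world_size as a read-only property, so `self.world_size = ...` raises AttributeError; renaming the attribute to global_world_size avoids the collision entirely. A minimal sketch of the failure mode, using an illustrative stand-in class rather than the real transformers definition:

```python
class TrainingArgsLike:
    """Stand-in for a class exposing world_size as a read-only property."""

    @property
    def world_size(self) -> int:
        # Derived from the distributed environment; no setter is defined.
        return 1


args = TrainingArgsLike()
try:
    args.world_size = 4  # AttributeError: the property has no setter
except AttributeError:
    pass  # the pre-fix code swallowed this, silently keeping world_size == 1

args.global_world_size = 4  # a differently named attribute binds without conflict
print(args.world_size, args.global_world_size)  # -> 1 4
```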
2 changes: 1 addition & 1 deletion in swift/llm/argument/train_args.py

```diff
@@ -57,7 +57,7 @@ def _init_eval_strategy(self):
     def __post_init__(self):
         self._init_output_dir()
         if self.average_tokens_across_devices is None:
-            self.average_tokens_across_devices = self.world_size > 1
+            self.average_tokens_across_devices = self.global_world_size > 1
         if self.metric_for_best_model is None:
             self.metric_for_best_model = 'rouge-l' if self.predict_with_generate else 'loss'
         if self.greater_is_better is None:
```
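Note: this is the usual __post_init__ pattern of resolving None-valued options once the distributed setup is known. A stripped-down sketch, assuming a get_dist_setting() helper that returns (rank, local_rank, world_size, local_world_size) as in the base_args diff above:

```python
from dataclasses import dataclass
from typing import Optional


def get_dist_setting():
    # Stand-in: the real helper reads RANK/LOCAL_RANK/WORLD_SIZE-style env vars.
    return 0, 0, 1, 1


@dataclass
class TrainArgsSketch:
    average_tokens_across_devices: Optional[bool] = None

    def __post_init__(self):
        _, _, self.global_world_size, _ = get_dist_setting()
        if self.average_tokens_across_devices is None:
            # Averaging token counts only matters with more than one process.
            self.average_tokens_across_devices = self.global_world_size > 1
```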
6 changes: 3 additions & 3 deletions in swift/llm/infer/infer.py

```diff
@@ -182,9 +182,9 @@ def infer_dataset(self) -> List[Dict[str, Any]]:
                 if self.jsonl_writer:
                     self.jsonl_writer.append(data)
         else:
-            is_dist = args.world_size > 1 and dist.is_initialized()
+            is_dist = args.global_world_size > 1 and dist.is_initialized()
             if is_dist:
-                val_dataset = val_dataset.shard(args.world_size, args.rank, contiguous=True)
+                val_dataset = val_dataset.shard(args.global_world_size, args.rank, contiguous=True)
             val_dataset = list(val_dataset)
             labels_list = [InferRequest.remove_response(data['messages']) for data in val_dataset]

@@ -197,7 +197,7 @@ def infer_dataset(self) -> List[Dict[str, Any]]:
                 data = {'response': response, 'logprobs': resp.choices[0].logprobs, **data}
                 result_list.append(data)
             if is_dist:
-                total_result_list = [None for _ in range(args.world_size)] if args.rank == 0 else None
+                total_result_list = [None for _ in range(args.global_world_size)] if args.rank == 0 else None
                 dist.gather_object(result_list, total_result_list)
                 result_list = total_result_list and list(chain.from_iterable(total_result_list))
```
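Note: the shard-then-gather pattern above gives each rank a contiguous slice of the eval set and collects the per-rank results on rank 0. A self-contained sketch of the same flow, assuming torch.distributed is already initialized (e.g. under torchrun) and substituting a plain list for the datasets.Dataset:

```python
from itertools import chain

import torch.distributed as dist


def shard_contiguous(items, num_shards, index):
    """Contiguous split in the spirit of datasets.Dataset.shard(..., contiguous=True)."""
    div, mod = divmod(len(items), num_shards)
    start = index * div + min(index, mod)
    return items[start:start + div + (1 if index < mod else 0)]


def run_sharded_inference(items, infer_one):
    rank, world_size = dist.get_rank(), dist.get_world_size()
    result_list = [infer_one(x) for x in shard_contiguous(items, world_size, rank)]
    # Rank 0 supplies one slot per rank; every other rank passes None.
    total = [None for _ in range(world_size)] if rank == 0 else None
    dist.gather_object(result_list, total)  # collective: all ranks must call it
    return list(chain.from_iterable(total)) if rank == 0 else None
```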
4 changes: 2 additions & 2 deletions in swift/utils/torchacc_utils.py

```diff
@@ -254,9 +254,9 @@ def save_ta_fsdp_checkpoint(self_model, tokenizer, args, output_dir):
         'shard_metadata': self_model._get_underlay_model().get_shard_metadata(),
     }
     if isinstance(model, PeftModel):
-        ckpt_path = os.path.join(output_dir, f'rank{args.process_index}-of-{args.world_size}-adapter_model.bin')
+        ckpt_path = os.path.join(output_dir, f'rank{args.process_index}-of-{args.global_world_size}-adapter_model.bin')
     else:
-        ckpt_path = os.path.join(output_dir, f'rank{args.process_index}-of-{args.world_size}-pytorch_model.bin')
+        ckpt_path = os.path.join(output_dir, f'rank{args.process_index}-of-{args.global_world_size}-pytorch_model.bin')
     xm.save(ckpt, ckpt_path, master_only=False)
     # Make sure all ranks have saved checkpoints
     xm.rendezvous('save_full_checkpoints')
```
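Note: under FSDP on torch_xla, every rank owns a shard of the model, so each rank writes its own rank-tagged file and then all ranks block at a rendezvous before continuing. A minimal sketch of that save/synchronize flow; the real save_ta_fsdp_checkpoint also packs shard metadata alongside the weights:

```python
import os

import torch_xla.core.xla_model as xm


def save_per_rank_checkpoint(state_dict, output_dir, world_size):
    rank = xm.get_ordinal()
    ckpt_path = os.path.join(output_dir, f'rank{rank}-of-{world_size}-pytorch_model.bin')
    # master_only=False: every rank persists its own shard, not just rank 0.
    xm.save(state_dict, ckpt_path, master_only=False)
    # Barrier so no rank races ahead (e.g. to consolidation) before all shards exist.
    xm.rendezvous('save_full_checkpoints')
```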
