
Commit

fix per_gpu args (#1809)
lvyufeng authored Nov 11, 2024
1 parent a5cfaaf commit d5ac1fe
Showing 2 changed files with 7 additions and 7 deletions.
6 changes: 3 additions & 3 deletions mindnlp/accelerate/utils/modeling.py
@@ -213,7 +213,7 @@ def get_balanced_memory(
         break  # only one device

     module_sizes = compute_module_sizes(model, dtype=dtype, special_dtypes=special_dtypes)
-    per_gpu = module_sizes[""] // (num_devices - 1 if low_zero else num_devices)
+    per_device = module_sizes[""] // (num_devices - 1 if low_zero else num_devices)

     # We can't just set the memory to model_size // num_devices as it will end being too small: each GPU will get
     # slightly less layers and some layers will end up offload at the end. So this function computes a buffer size to
@@ -251,7 +251,7 @@ def get_balanced_memory(
         leaves = get_module_leaves(module_sizes)
         mean_leaves = int(sum(module_sizes[n] for n in leaves) / max(len(leaves), 1))
         buffer = int(1.25 * max(buffer, mean_leaves))
-        per_gpu += buffer
+        per_device += buffer

     # Sorted list of GPUs id (we may have some gpu ids not included in the our max_memory list - let's ignore them)
     gpus_idx_list = list(
@@ -261,7 +261,7 @@ def get_balanced_memory(
     )
     # The last device is left with max_memory just in case the buffer is not enough.
     for idx in gpus_idx_list[:-1]:
-        max_memory[idx] = min(max_memory[0] if low_zero and idx == 0 else per_gpu, max_memory[idx])
+        max_memory[idx] = min(max_memory[0] if low_zero and idx == 0 else per_device, max_memory[idx])

     if low_zero:
         min_zero = max(0, module_sizes[""] - sum(max_memory[i] for i in range(1, num_devices)))
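The hunks above rename per_gpu to per_device throughout get_balanced_memory, matching the device-agnostic naming already used by num_devices and the per_device_* batch-size arguments. As a rough illustration of the logic being touched, here is a minimal sketch (hypothetical and simplified: balance_memory, its flat module_sizes/max_memory dicts, and the buffer parameter are illustrative stand-ins, not the real get_balanced_memory signature):

def balance_memory(module_sizes, max_memory, num_devices, buffer, low_zero=False):
    """Cap each device's budget so layers spread evenly across devices."""
    total = module_sizes[""]  # "" keys the whole model's size, as in the diff
    # Split the model evenly; with low_zero, device 0 is left out of the split.
    per_device = total // (num_devices - 1 if low_zero else num_devices)
    per_device += buffer  # headroom so trailing layers are not offloaded
    # Cap every device except the last, which keeps its full budget as slack.
    for idx in sorted(max_memory)[:-1]:
        cap = max_memory[0] if low_zero and idx == 0 else per_device
        max_memory[idx] = min(cap, max_memory[idx])
    return max_memory

# e.g. a 30 GB model on four 24 GB devices with a 1 GB buffer:
# balance_memory({"": 30}, {0: 24, 1: 24, 2: 24, 3: 24}, 4, buffer=1)
# -> {0: 8, 1: 8, 2: 8, 3: 24}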
8 changes: 4 additions & 4 deletions mindnlp/engine/train_args/base.py
@@ -1164,8 +1164,8 @@ def __str__(self):

         # Remove deprecated arguments. That code should be removed once
         # those deprecated arguments are removed from TrainingArguments. (TODO: v5)
-        del self_as_dict["per_gpu_train_batch_size"]
-        del self_as_dict["per_gpu_eval_batch_size"]
+        del self_as_dict["per_device_train_batch_size"]
+        del self_as_dict["per_device_eval_batch_size"]

         self_as_dict = {k: f"<{k.upper()}>" if k.endswith("_token") else v for k, v in self_as_dict.items()}

@@ -1193,7 +1193,7 @@ def n_device(self):
     @property
     def train_batch_size(self) -> int:
         """
-        The actual batch size for training (may differ from `per_gpu_train_batch_size` in distributed training).
+        The actual batch size for training (may differ from `per_device_train_batch_size` in distributed training).
         """
         per_device_batch_size = self.per_device_train_batch_size
         train_batch_size = per_device_batch_size * max(1, self.n_device)
@@ -1202,7 +1202,7 @@ def train_batch_size(self):
     @property
     def eval_batch_size(self) -> int:
         """
-        The actual batch size for evaluation (may differ from `per_gpu_eval_batch_size` in distributed training).
+        The actual batch size for evaluation (may differ from `per_device_eval_batch_size` in distributed training).
         """
         per_device_batch_size = self.per_device_eval_batch_size
         eval_batch_size = per_device_batch_size * max(1, self.n_device)
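To see what the renamed arguments feed into, here is a small sketch of the property logic in the hunks above (ArgsSketch is a hypothetical stand-in, not the real TrainingArguments class):

class ArgsSketch:
    def __init__(self, per_device_train_batch_size: int, n_device: int):
        self.per_device_train_batch_size = per_device_train_batch_size
        self.n_device = n_device

    @property
    def train_batch_size(self) -> int:
        # Effective batch size = per-device batch size * number of devices.
        return self.per_device_train_batch_size * max(1, self.n_device)

# With per_device_train_batch_size=8 on 4 devices:
# ArgsSketch(8, 4).train_batch_size -> 32

eval_batch_size follows the same pattern with per_device_eval_batch_size.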
