
Commit

[feature] make timer optional and make reduce bucket size configurable (#549)

* [feature] make reduce bucket size configurable

* [feature] make timer optional
ver217 authored Jun 27, 2024
1 parent 4582e8d commit 278cd75
Showing 4 changed files with 54 additions and 29 deletions.
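In practice the two features are driven by two keys read with `cfg.get(...)` in scripts/train.py. A hypothetical excerpt of a training config, assuming the project's Python-style config files; only these two keys come from this commit:

```python
# Hypothetical config excerpt (everything else in the real config is unchanged).
record_time = True              # also settable via the new --record-time CLI flag
reduce_bucket_size_in_m = 20    # forwarded to the ColossalAI zero plugins; 20 is the default
```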
1 change: 1 addition & 0 deletions opensora/utils/config_utils.py
@@ -79,6 +79,7 @@ def parse_args(training=False):
parser.add_argument("--load", default=None, type=str, help="path to continue training")
parser.add_argument("--start-from-scratch", action="store_true", help="start training from scratch")
parser.add_argument("--warmup-steps", default=None, type=int, help="warmup steps")
parser.add_argument("--record-time", default=False, action="store_true", help="record time of each part")

return parser.parse_args()
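For reference, a self-contained sketch of the flag's behavior (argparse `store_true` with a `False` default), mirroring the line added above; the parser here is a stand-in, not the project's full `parse_args`:

```python
import argparse

# Stand-in parser: --record-time defaults to False and flips to True when passed,
# which is then read downstream as cfg.get("record_time", False).
parser = argparse.ArgumentParser()
parser.add_argument("--record-time", default=False, action="store_true", help="record time of each part")

args = parser.parse_args(["--record-time"])  # simulate launching with the flag
print(args.record_time)                      # True; without the flag this stays False
```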

8 changes: 6 additions & 2 deletions opensora/utils/misc.py
@@ -6,11 +6,12 @@
from collections import OrderedDict
from collections.abc import Sequence
from itertools import repeat
from typing import Tuple
from typing import Optional, Tuple

import numpy as np
import torch
import torch.distributed as dist
from colossalai.cluster.dist_coordinator import DistCoordinator

# ======================================================
# Logging
@@ -358,11 +359,12 @@ def all_exists(paths):


class Timer:
def __init__(self, name, log=False):
def __init__(self, name, log=False, coordinator: Optional[DistCoordinator] = None):
self.name = name
self.start_time = None
self.end_time = None
self.log = log
self.coordinator = coordinator

@property
def elapsed_time(self):
@@ -374,6 +376,8 @@ def __enter__(self):
return self

def __exit__(self, exc_type, exc_val, exc_tb):
if self.coordinator is not None:
self.coordinator.block_all()
torch.cuda.synchronize()
self.end_time = time.time()
if self.log:
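A hypothetical usage sketch of the extended Timer, not taken from the diff itself: it assumes a CUDA device and an initialized distributed process group, and the workload function is a placeholder.

```python
import time

from colossalai.cluster.dist_coordinator import DistCoordinator
from opensora.utils.misc import Timer

def run_step():
    time.sleep(0.01)  # placeholder workload

coordinator = DistCoordinator()  # requires torch.distributed to be initialized

# With a coordinator, __exit__ calls block_all() before the CUDA sync, so the
# measured time reflects the slowest rank rather than only the local one.
with Timer("encode", coordinator=coordinator) as t:
    run_step()
print(f"encode: {t.elapsed_time:.3f}s")

# Without a coordinator (the previous behavior), no cross-rank barrier is taken.
with Timer("encode") as t:
    run_step()
```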
4 changes: 3 additions & 1 deletion opensora/utils/train_utils.py
@@ -12,14 +12,15 @@
from .misc import get_logger


def create_colossalai_plugin(plugin, dtype, grad_clip, sp_size):
def create_colossalai_plugin(plugin, dtype, grad_clip, sp_size, reduce_bucket_size_in_m: int = 20):
if plugin == "zero2":
assert sp_size == 1, "Zero2 plugin does not support sequence parallelism"
plugin = LowLevelZeroPlugin(
stage=2,
precision=dtype,
initial_scale=2**16,
max_norm=grad_clip,
reduce_bucket_size_in_m=reduce_bucket_size_in_m,
)
set_data_parallel_group(dist.group.WORLD)
elif plugin == "zero2-seq":
@@ -30,6 +31,7 @@ def create_colossalai_plugin(plugin, dtype, grad_clip, sp_size):
precision=dtype,
initial_scale=2**16,
max_norm=grad_clip,
reduce_bucket_size_in_m=reduce_bucket_size_in_m,
)
set_sequence_parallel_group(plugin.sp_group)
set_data_parallel_group(plugin.dp_group)
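A call-site sketch mirroring how scripts/train.py wires the new argument; the dtype and grad_clip values here are illustrative, train.py reads them from the config.

```python
from colossalai.booster import Booster
from opensora.utils.train_utils import create_colossalai_plugin

plugin = create_colossalai_plugin(
    plugin="zero2",                  # "zero2-seq" enables sequence parallelism
    dtype="bf16",                    # illustrative value
    grad_clip=1.0,                   # illustrative value
    sp_size=1,                       # the zero2 branch asserts sp_size == 1
    reduce_bucket_size_in_m=20,      # new argument; cfg.get("reduce_bucket_size_in_m", 20) upstream
)
booster = Booster(plugin=plugin)
```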
70 changes: 44 additions & 26 deletions scripts/train.py
@@ -1,4 +1,5 @@
import os
from contextlib import nullcontext
from copy import deepcopy
from datetime import timedelta
from pprint import pformat
@@ -38,6 +39,7 @@ def main():
# ======================================================
# == parse configs ==
cfg = parse_configs(training=True)
record_time = cfg.get("record_time", False)

# == device and dtype ==
assert torch.cuda.is_available(), "Training currently requires at least one GPU."
@@ -76,6 +78,7 @@ def main():
dtype=cfg_dtype,
grad_clip=cfg.get("grad_clip", 0),
sp_size=cfg.get("sp_size", 1),
reduce_bucket_size_in_m=cfg.get("reduce_bucket_size_in_m", 20),
)
booster = Booster(plugin=plugin)
torch.set_num_threads(1)
@@ -229,6 +232,21 @@ def main():
# 5. training loop
# =======================================================
dist.barrier()
timers = {}
timer_keys = [
"move_data",
"encode",
"mask",
"diffusion",
"backward",
"update_ema",
"reduce_loss",
]
for key in timer_keys:
if record_time:
timers[key] = Timer(key, coordinator=coordinator)
else:
timers[key] = nullcontext()
for epoch in range(start_epoch, cfg_epochs):
# == set dataloader to new epoch ==
sampler.set_epoch(epoch)
@@ -245,13 +263,14 @@ def main():
) as pbar:
for step, batch in pbar:
timer_list = []
with Timer("move data") as move_data_t:
with timers["move_data"] as move_data_t:
x = batch.pop("video").to(device, dtype) # [B, C, T, H, W]
y = batch.pop("text")
timer_list.append(move_data_t)
if record_time:
timer_list.append(move_data_t)

# == visual and text encoding ==
with Timer("encode") as encode_t:
with timers["encode"] as encode_t:
with torch.no_grad():
# Prepare visual inputs
if cfg.get("load_video_features", False):
@@ -267,31 +286,31 @@ def main():
model_args["mask"] = mask
else:
model_args = text_encoder.encode(y)
coordinator.block_all()
timer_list.append(encode_t)
if record_time:
timer_list.append(encode_t)

# == mask ==
with Timer("mask") as mask_t:
with timers["mask"] as mask_t:
mask = None
if cfg.get("mask_ratios", None) is not None:
mask = mask_generator.get_masks(x)
model_args["x_mask"] = mask
coordinator.block_all()
timer_list.append(mask_t)
if record_time:
timer_list.append(mask_t)

# == video meta info ==
for k, v in batch.items():
if isinstance(v, torch.Tensor):
model_args[k] = v.to(device, dtype)

# == diffusion loss computation ==
with Timer("diffusion") as loss_t:
with timers["diffusion"] as loss_t:
loss_dict = scheduler.training_losses(model, x, model_args, mask=mask)
coordinator.block_all()
timer_list.append(loss_t)
if record_time:
timer_list.append(loss_t)

# == backward & update ==
with Timer("backward") as backward_t:
with timers["backward"] as backward_t:
loss = loss_dict["loss"].mean()
booster.backward(loss=loss, optimizer=optimizer)
optimizer.step()
@@ -300,24 +319,24 @@ def main():
# update learning rate
if lr_scheduler is not None:
lr_scheduler.step()
coordinator.block_all()
timer_list.append(backward_t)
if record_time:
timer_list.append(backward_t)

# == update EMA ==
with Timer("update_ema") as ema_t:
with timers["update_ema"] as ema_t:
update_ema(ema, model.module, optimizer=optimizer, decay=cfg.get("ema_decay", 0.9999))
coordinator.block_all()
timer_list.append(ema_t)
if record_time:
timer_list.append(ema_t)

# == update log info ==
with Timer("reduce_loss") as reduce_loss_t:
with timers["reduce_loss"] as reduce_loss_t:
all_reduce_mean(loss)
running_loss += loss.item()
global_step = epoch * num_steps_per_epoch + step
log_step += 1
acc_step += 1
coordinator.block_all()
timer_list.append(reduce_loss_t)
if record_time:
timer_list.append(reduce_loss_t)

# == logging ==
if coordinator.is_master() and (global_step + 1) % cfg.get("log_every", 1) == 0:
@@ -376,12 +395,11 @@ def main():
global_step + 1,
save_dir,
)

log_str = f"Rank {dist.get_rank()} | Epoch {epoch} | Step {step} | "
for timer in timer_list:
log_str += f"{timer.name}: {timer.elapsed_time:.3f}s | "
print(log_str)
coordinator.block_all()
if record_time:
log_str = f"Rank {dist.get_rank()} | Epoch {epoch} | Step {step} | "
for timer in timer_list:
log_str += f"{timer.name}: {timer.elapsed_time:.3f}s | "
print(log_str)

sampler.reset()
start_step = 0
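A condensed, self-contained sketch of the pattern this file now uses: every timed region always runs under a context manager, but when `record_time` is False the manager is a no-op `nullcontext`, so the hot loop carries no extra barriers or CUDA syncs. The `FakeTimer` stand-in replaces the real `Timer` only so the snippet runs on its own.

```python
from contextlib import nullcontext

record_time = False  # cfg.get("record_time", False) in the real script

class FakeTimer:
    """Stand-in for opensora.utils.misc.Timer, just to keep the sketch self-contained."""
    def __init__(self, name):
        self.name = name
        self.elapsed_time = 0.0
    def __enter__(self):
        return self
    def __exit__(self, *exc):
        return False

# Build one context manager per timed section, real timers only when requested.
timers = {key: (FakeTimer(key) if record_time else nullcontext())
          for key in ["move_data", "encode"]}

timer_list = []
with timers["move_data"] as t:   # t is None when the manager is nullcontext()
    pass                         # move batch to device
if record_time:
    timer_list.append(t)

# Per-step timing log, printed only when timing is enabled.
if record_time:
    print(" | ".join(f"{timer.name}: {timer.elapsed_time:.3f}s" for timer in timer_list))
```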
