RuntimeTimer for the toolkit (PaddlePaddle#7913)
* RuntimeTimer for the toolkit

* RuntimeTimer for the toolkit

* reformat

* fix timer and load checkpoints

* remove reset
KB-Ding authored Jan 29, 2024
1 parent 9a31322 commit 6e0ac44
Showing 3 changed files with 48 additions and 8 deletions.
7 changes: 4 additions & 3 deletions paddlenlp/trainer/__init__.py
@@ -13,11 +13,12 @@
# limitations under the License.

from .argparser import *
from .training_args import *
from .compression_args import *
from .plugins.timer import *
from .trainer import *
from .trainer_callback import *
from .trainer_utils import *
from .trainer_compress import *
from .training_args_seq2seq import *
from .trainer_seq2seq import *
from .trainer_utils import *
from .training_args import *
from .training_args_seq2seq import *
30 changes: 28 additions & 2 deletions paddlenlp/trainer/plugins/timer.py
@@ -31,14 +31,14 @@ def __init__(self, name):

def start(self):
"""Start the timer."""
assert not self.started_, "timer has already started"
assert not self.started_, f"{self.name} timer has already started"
paddle.device.synchronize()
self.start_time = time.time()
self.started_ = True

def stop(self):
"""Stop the timers."""
assert self.started_, "timer is not started."
assert self.started_, f"{self.name} timer is not started."
paddle.device.synchronize()
self.elapsed_ += time.time() - self.start_time
self.started_ = False
@@ -65,6 +65,32 @@ def elapsed(self, reset=True):
return elapsed_


class RuntimeTimer:
"""A timer that can be dynamically adjusted during runtime."""

def __init__(self, name):
self.timer = _Timer(name)

def start(self, name):
"""Start the RuntimeTimer."""
self.timer.name = name
self.timer.start()

def stop(self):
"""Stop the RuntimeTimer."""
self.timer.stop()

def log(self):
"""Log, stop and reset the RuntimeTimer."""
runtime = self.timer.elapsed(reset=True)
if self.timer.started_ is True:
self.timer.stop()
self.timer.reset()

string = "[timelog] {}: {:.2f}s ({}) ".format(self.timer.name, runtime, time.strftime("%Y-%m-%d %H:%M:%S"))
return string


class Timers:
"""Group of timers."""

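For orientation, here is a minimal usage sketch of the RuntimeTimer added above. It is not part of the diff; it assumes paddle and paddlenlp are installed and that paddle.device.synchronize(), which _Timer.start()/stop() call, works on the current device.

import time

from paddlenlp.trainer.plugins.timer import RuntimeTimer

# Create one timer; the name passed to start() labels the current phase.
runtime_timer = RuntimeTimer("RuntimeTimer")

runtime_timer.start("checkpoint loading time")
time.sleep(0.5)  # stand-in for the work being measured
runtime_timer.stop()

# log() reads the accumulated time, resets the timer, and returns a line such as
# "[timelog] checkpoint loading time: 0.50s (2024-01-29 10:00:00)".
print(runtime_timer.log())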
19 changes: 16 additions & 3 deletions paddlenlp/trainer/trainer.py
@@ -104,7 +104,7 @@
from ..utils.log import logger
from .argparser import strtobool
from .integrations import get_reporting_integration_callbacks
from .plugins.timer import get_timers, set_timers
from .plugins.timer import RuntimeTimer, get_timers, set_timers
from .plugins.unified_checkpoint import (
load_unified_checkpoint,
load_unified_optimizer,
@@ -304,6 +304,7 @@ def __init__(
if not args.skip_profile_timer:
set_timers()
self.timers = get_timers()
self.runtime_timer = RuntimeTimer("RuntimeTimer")

self.model_wrapped = model
self.model = model
@@ -508,6 +509,7 @@ def _load_from_checkpoint(self, resume_from_checkpoint=None):
`bool` and equals `True`, load the last checkpoint in *args.output_dir* as saved by a previous instance
of [`Trainer`]. Only load model state dict.
"""
self.runtime_timer.start("checkpoint loading time")
resume_from_checkpoint = None if not resume_from_checkpoint else resume_from_checkpoint

# Load potential model checkpoint
@@ -533,10 +535,12 @@ def _load_from_checkpoint(self, resume_from_checkpoint=None):
safe_serialization=True,
)
logger.info(f"Loading model from {resume_from_checkpoint} using unified checkpoint.")
self.runtime_timer.stop()
return

if isinstance(self.model, LoRAModel) or isinstance(self.model, PrefixModelForCausalLM):
self._load_from_peft_checkpoint(resume_from_checkpoint)
self.runtime_timer.stop()
return

weight_name = PADDLE_WEIGHTS_NAME
@@ -586,6 +590,7 @@ def _load_from_checkpoint(self, resume_from_checkpoint=None):

elif resume_from_checkpoint is not None:
logger.info(f"not loading ckpt :{self.args.dataset_rank}")
self.runtime_timer.stop()

def _wrap_model_and_load_sharded_checkpoint(self, resume_from_checkpoint):
# In the sharded mode, should invoke _load_from_checkpoint after _wrap_model.
@@ -641,7 +646,6 @@ def train(

# memory metrics - must set up as early as possible
self._memory_tracker.start()

if not self.args.should_load_sharding_stage1_model:
self._load_from_checkpoint(resume_from_checkpoint)

@@ -697,6 +701,7 @@

if self.args.should_load_sharding_stage1_model:
model = self._wrap_model_and_load_sharded_checkpoint(resume_from_checkpoint)

elif self.args.should_save_sharding_stage1_model:
# In the non-sharded mode, should invoke _load_from_checkpoint before _wrap_model.
# In this mode, the rank0 load all params and the _wrap_model implicitly broadcast params from rank0 to the other ranks.
@@ -720,6 +725,8 @@
self.create_optimizer_and_scheduler(num_training_steps=max_steps)
self._load_optimizer_and_scheduler(resume_from_checkpoint)

logger.info(f"{self.runtime_timer.log()}")

logger.info("***** Running training *****")
logger.info(f" Num examples = {num_examples:,}")
logger.info(f" Num Epochs = {num_train_epochs}")
@@ -1268,6 +1275,7 @@ def _maybe_log_save_evaluate(self, tr_loss, model, epoch, ignore_keys_for_eval,
paddle.device.cuda.synchronize()

self._save_checkpoint(model, metrics=metrics)
logger.info(f"{self.runtime_timer.log()}")
self.control = self.callback_handler.on_save(self.args, self.state, self.control)

def _get_learning_rate(self):
@@ -2071,7 +2079,7 @@ def save_model(self, output_dir: Optional[str] = None, merge_tensor_parallel: Op

def _save_checkpoint(self, model, metrics=None):
# assert unwrap_model(model) is self.model, "internal model should be a reference to self.model"

self.runtime_timer.start("checkpoint saving time")
# Save model checkpoint
checkpoint_folder = f"{PREFIX_CHECKPOINT_DIR}-{self.state.global_step}"

@@ -2117,6 +2125,7 @@
if self.do_grad_scaling:
paddle.save(self.scaler.state_dict(), os.path.join(output_dir, SCALER_NAME))

self.runtime_timer.stop()
# Determine the new best metric / best model checkpoint
if metrics is not None and self.args.metric_for_best_model is not None:
metric_to_check = self.args.metric_for_best_model
@@ -2336,10 +2345,13 @@ def _save(self, output_dir: Optional[str] = None, state_dict=None, merge_tensor_

def _load_optimizer_and_scheduler(self, checkpoint):
"""If optimizer and scheduler states exist, load them."""
self.runtime_timer.start("checkpoint loading time")
if checkpoint is None:
self.runtime_timer.stop()
return

if (not self.args.should_load_sharding_stage1_model) and self.args.ignore_load_lr_and_optim:
self.runtime_timer.stop()
return

opt_state_dict = None
@@ -2398,6 +2410,7 @@ def _load_optimizer_and_scheduler(self, checkpoint):
self.scaler.load_state_dict(
paddle.load(distributed_file(os.path.join(checkpoint, SCALER_NAME)), return_numpy=True)
)
self.runtime_timer.stop()

def log(self, logs: Dict[str, float], **kwargs) -> None:
"""
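Taken together, the trainer.py changes follow one pattern: name a phase with runtime_timer.start(...), call stop() on every exit path, and report the accumulated time via logger.info(f"{self.runtime_timer.log()}"). Below is a condensed, hypothetical sketch of that pattern in a standalone class (illustration only, not the real Trainer; a try/finally stands in for the per-return stop() calls in the diff).

import logging

from paddlenlp.trainer.plugins.timer import RuntimeTimer

logger = logging.getLogger(__name__)


class CheckpointTimingSketch:
    def __init__(self):
        # One reusable timer; each phase renames it through start(name).
        self.runtime_timer = RuntimeTimer("RuntimeTimer")

    def _load_from_checkpoint(self, resume_from_checkpoint=None):
        self.runtime_timer.start("checkpoint loading time")
        try:
            ...  # load model weights, optimizer, and scheduler state here
        finally:
            # The diff calls stop() before each early return; try/finally
            # gives the same guarantee in this condensed sketch.
            self.runtime_timer.stop()

    def _save_checkpoint(self):
        self.runtime_timer.start("checkpoint saving time")
        ...  # write model, optimizer, and scaler state to disk
        self.runtime_timer.stop()
        # In the diff the caller (train / _maybe_log_save_evaluate) logs and
        # resets the timer; inlined here for brevity.
        logger.info(f"{self.runtime_timer.log()}")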
