minor updates trainer almost complete
Parry-Parry committed Dec 1, 2024
1 parent 8fb7c29 commit 24f7d6f
Showing 2 changed files with 140 additions and 38 deletions.
41 changes: 32 additions & 9 deletions rankers/datasets/flaxloader.py
@@ -3,9 +3,15 @@
from flax.training.common_utils import shard
import jax.numpy as jnp

def process_distributed_tokens(tokens):
    if not isinstance(tokens, dict):
        return shard(jnp.array(tokens))
    return {k: shard(jnp.array(v)) for k, v in tokens.items()}

def process_tokens(tokens):
-    return {k: shard(jnp.array(v)) for k, v in tokens.items()}
+    if not isinstance(tokens, dict):
+        return jnp.array(tokens)
+    return {k: jnp.array(v) for k, v in tokens.items()}


class FlaxDotDataCollator:
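For context on the split above: `shard` (from `flax.training.common_utils`) reshapes the leading batch axis to `(jax.local_device_count(), batch_size // jax.local_device_count(), ...)` so a `pmap`-ped step receives one slice per device, while the non-distributed path now returns plain device arrays. A minimal sketch, not part of the diff (shapes assume a batch of 8 divisible by the device count):

import jax
import jax.numpy as jnp
import numpy as np
from flax.training.common_utils import shard

# Sketch only: contrast the two token-processing paths introduced above.
tokens = {
    "input_ids": np.zeros((8, 30), dtype=np.int32),
    "attention_mask": np.ones((8, 30), dtype=np.int32),
}

plain = {k: jnp.array(v) for k, v in tokens.items()}           # (8, 30)
sharded = {k: shard(jnp.array(v)) for k, v in tokens.items()}  # (n_devices, 8 // n_devices, 30)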
@@ -15,11 +21,13 @@ def __init__(
special_mask=False,
q_max_length=30,
d_max_length=200,
distributed=False,
) -> None:
self.tokenizer = tokenizer
self.q_max_length = q_max_length
self.d_max_length = d_max_length
self.special_mask = special_mask
self.process_tokens = process_distributed_tokens if distributed else process_tokens

def __call__(self, batch) -> dict:
batch_queries = []
@@ -50,10 +58,10 @@ def __call__(self, batch) -> dict:
)

return {
"queries": process_tokens(dict(tokenized_queries)),
"docs_batch": process_tokens(dict(tokenized_docs)),
"queries": self.process_tokens(dict(tokenized_queries)),
"docs_batch": self.process_tokens(dict(tokenized_docs)),
"labels": (
shard(jnp.array(np.array(batch_scores)))
self.process_tokens(jnp.array(np.array(batch_scores)))
if len(batch_scores) > 0
else None
),
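A hypothetical usage sketch for the new flag; the `(query, docs, scores)` item layout and the tokenizer checkpoint are assumptions, since the collator's loop body is collapsed in this view:

# Hypothetical usage; the batch item layout is inferred from the returned
# keys, not from the collapsed __call__ body.
from transformers import AutoTokenizer

tokenizer = AutoTokenizer.from_pretrained("bert-base-uncased")
batch = [("what is flax?", ["Flax is a JAX library.", "Unrelated text."], [1.0, 0.0])] * 8

features = FlaxDotDataCollator(tokenizer, distributed=True)(batch)
# features["queries"] / features["docs_batch"]: sharded token dicts
# features["labels"]: sharded score array, or None when no scores are given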
@@ -66,10 +74,12 @@ def __init__(
tokenizer,
q_max_length=30,
d_max_length=200,
distributed=False,
) -> None:
self.tokenizer = tokenizer
self.q_max_length = q_max_length
self.d_max_length = d_max_length
self.process_tokens = process_distributed_tokens if distributed else process_tokens

def __call__(self, batch) -> dict:
batch_queries = []
@@ -93,7 +103,7 @@ def __call__(self, batch) -> dict:
return {
"sequences": process_tokens(dict(tokenized_sequences)),
"labels": (
-                shard(jnp.array(np.array(batch_scores)))
+                self.process_tokens(np.array(batch_scores))
if len(batch_scores) > 0
else None
),
@@ -109,9 +119,14 @@ def _make_pos_pairs(texts) -> list:


class FlaxPairDataCollator:
-    def __init__(self, tokenizer, max_length=512) -> None:
+    def __init__(self,
+                 tokenizer,
+                 max_length=512,
+                 distributed=False
+    ) -> None:
self.tokenizer = tokenizer
self.max_length = max_length
self.process_tokens = process_distributed_tokens if distributed else process_tokens

def __call__(self, batch) -> dict:
batch_queries = []
@@ -158,10 +173,12 @@ def __init__(
tokenizer,
prompt: Any,
max_length=512,
distributed=False,
) -> None:
self.tokenizer = tokenizer
self.prompt = prompt
self.max_length = max_length
self.process_tokens = process_distributed_tokens if distributed else process_tokens

def __call__(self, batch) -> dict:
batch_queries = []
@@ -189,18 +206,24 @@ def __call__(self, batch) -> dict:
return {
"sequences": process_tokens(dict(tokenized_sequences)),
"labels": (
-                shard(jnp.array(np.array(batch_scores)))
+                self.process_tokens(jnp.array(np.array(batch_scores)))
if len(batch_scores) > 0
else None
),
}


class FlaxPairPromptDataCollator:
-    def __init__(self, tokenizer, prompt: Any, max_length=512) -> None:
+    def __init__(self,
+                 tokenizer,
+                 prompt: Any,
+                 max_length=512,
+                 distributed=False
+    ) -> None:
self.tokenizer = tokenizer
self.max_length = max_length
self.prompt = prompt
self.process_tokens = process_distributed_tokens if distributed else process_tokens

def __call__(self, batch) -> dict:
batch_queries = []
Expand Down Expand Up @@ -234,7 +257,7 @@ def __call__(self, batch) -> dict:
return {
"sequences": process_tokens(dict(tokenized_sequences)),
"labels": (
-                shard(jnp.squeeze(jnp.array(np.array(batch_scores))))
+                self.process_tokens(jnp.squeeze(jnp.array(np.array(batch_scores))))
if len(batch_scores) > 0
else None
),
137 changes: 108 additions & 29 deletions rankers/train/flax/trainer.py
@@ -1,4 +1,5 @@
import os
import shutil
import sys
import logging
from transformers import Trainer
@@ -15,8 +16,9 @@
TrainerCallback,
TrainerControl,
PrinterCallback,
hf_hub_utils,
)
-from transformers.trainer_utils import EvalLoopOutput, speed_metrics
+from transformers.trainer_utils import EvalLoopOutput, speed_metrics, TrainOutput
from transformers.trainer_callback import (
CallbackHandler,
DefaultFlowCallback,
@@ -27,6 +29,7 @@
from transformers.integrations import get_reporting_integration_callbacks
from functools import partial
import warnings
import jax
from jax import jit
import orbax
import jax.numpy as jnp
@@ -49,6 +52,12 @@
DEFAULT_CALLBACKS = [DefaultFlowCallback]
DEFAULT_PROGRESS_CALLBACK = ProgressCallback

'''
TODO:
- Implement call_model_init
- Implement _load_best_model
'''


class FlaxContrastTrainer(Trainer):
"""Customized Trainer from Huggingface's Trainer"""
@@ -70,7 +79,7 @@ def __init__(
optimizers: Tuple[optax._src.base.Optimizer, optax._src.base.Scheduler] = (
None,
None,
-        ), # TODO: Double check these base classes
+        )
):
if args is None:
output_dir = "tmp_trainer"
@@ -253,14 +262,13 @@ def create_optimizer_and_scheduler(self, num_training_steps=None):
@partial(jit, static_argnums=(0, 2))
def compute_loss(self, inputs, return_outputs=False):
labels = inputs.pop("labels")
-        pred, _ = self.state.apply_fn(**inputs)
+        outputs = self.state.apply_fn(**inputs)
        # Save past state if it exists
+        pred = outputs.pred
        if self.args.past_index >= 0:
            self._past = pred[self.args.past_index]

-        loss = self.loss(pred, labels)
-
-        return (loss, pred) if return_outputs else loss
+        return (outputs.loss, outputs.pred) if return_outputs else outputs.loss

def compute_metrics(self, result_frame: pd.DataFrame):
from ir_measures import evaluator, RR
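A note on the decorator kept above: `static_argnums=(0, 2)` marks `self` and `return_outputs` as static, so the Python-level branch on `return_outputs` is resolved at trace time and each value of the flag compiles its own version of the function. A self-contained sketch of the same pattern (names are illustrative):

from functools import partial
import jax.numpy as jnp
from jax import jit

class Scorer:
    # self (argnum 0) must be hashable and is baked in as a constant;
    # flipping return_aux (argnum 2) triggers a separate compilation.
    @partial(jit, static_argnums=(0, 2))
    def score(self, x, return_aux=False):
        logits = x * 2.0
        loss = jnp.sum(logits ** 2)
        return (loss, logits) if return_aux else loss

s = Scorer()
loss = s.score(jnp.ones((4,)))                # compiled once
loss, logits = s.score(jnp.ones((4,)), True)  # new static arg, new trace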
@@ -364,6 +372,14 @@ def get_num_trainable_parameters(self):
"""
return sum(p.numel() for p in self.model.parameters() if p.requires_grad)

def get_model_param_count(self, trainable_only: bool = False):
"""
Get the number of parameters of the model.
"""
if trainable_only:
return self.get_num_trainable_parameters()
return sum(p.numel() for p in self.model.parameters())

def get_learning_rates(self):
"""
Returns the learning rate of each parameter from self.optimizer.
@@ -374,6 +390,86 @@ def get_learning_rates(self):
)
return [group["lr"] for group in self.optimizer.param_groups]

def _load_best_model(self):
logger.info(f"Loading best model from {self.state.best_model_checkpoint} (score: {self.state.best_metric}).")
best_model_path = os.path.join(self.state.best_model_checkpoint, WEIGHTS_NAME)
best_safe_model_path = os.path.join(self.state.best_model_checkpoint, SAFE_WEIGHTS_NAME)
best_adapter_model_path = os.path.join(self.state.best_model_checkpoint, ADAPTER_WEIGHTS_NAME)
best_safe_adapter_model_path = os.path.join(self.state.best_model_checkpoint, ADAPTER_SAFE_WEIGHTS_NAME)

model = self.model

if (
os.path.exists(best_model_path)
or os.path.exists(best_safe_model_path)
or os.path.exists(best_adapter_model_path)
or os.path.exists(best_safe_adapter_model_path)
):
has_been_loaded = True
weights_only_kwarg = {}
# If the 'user_content.pt' file does NOT exist, load with the old smp api.
# Checkpoint must have been saved with the old smp api.
if self.args.save_safetensors and os.path.isfile(best_safe_model_path):
state_dict = safetensors.torch.load_file(best_safe_model_path, device="cpu")
else:
state_dict = torch.load(
best_model_path,
map_location="cpu",
**weights_only_kwarg,
)

load_result = model.load_state_dict(state_dict, strict=True)
else:
if _is_peft_model(model):
# If train a model using PEFT & LoRA, assume that adapter have been saved properly.
# TODO: in the future support only specific min PEFT versions
if (hasattr(model, "active_adapter") or hasattr(model, "active_adapters")) and hasattr(
model, "load_adapter"
):
# For BC for older PEFT versions
if hasattr(model, "active_adapters"):
active_adapter = model.active_adapters[0]
if len(model.active_adapters) > 1:
logger.warning("Detected multiple active adapters, will only consider the first one")
else:
active_adapter = model.active_adapter

if os.path.exists(best_adapter_model_path) or os.path.exists(best_safe_adapter_model_path):
model.load_adapter(self.state.best_model_checkpoint, active_adapter)
# Load_adapter has no return value present, modify it when appropriate.
from torch.nn.modules.module import _IncompatibleKeys

load_result = _IncompatibleKeys([], [])
else:
logger.warning(
"The intermediate checkpoints of PEFT may not be saved correctly, "
f"consider using a custom callback to save {ADAPTER_WEIGHTS_NAME} in corresponding saving folders. "
"Check some examples here: https://github.com/huggingface/peft/issues/96"
)
has_been_loaded = False
else:
logger.warning("Could not load adapter model, make sure to have `peft>=0.3.0` installed")
has_been_loaded = False
else:
# We load the model state dict on the CPU to avoid an OOM error.
if os.path.isfile(best_safe_model_path):
state_dict = safetensors.torch.load_file(best_safe_model_path, device="cpu")
else:
state_dict = torch.load(
best_model_path,
map_location="cpu",
**weights_only_kwarg,
)
load_result = model.load_state_dict(state_dict, False)
if has_been_loaded:
self._issue_warnings_after_load(load_result)
elif os.path.exists(os.path.join(self.state.best_model_checkpoint, SAFE_WEIGHTS_INDEX_NAME)) or os.path.exists(
os.path.join(self.state.best_model_checkpoint, WEIGHTS_INDEX_NAME)
):
load_result = load_sharded_checkpoint(
model, self.state.best_model_checkpoint
)

def train(
self,
resume_from_checkpoint: Optional[Union[str, bool]] = None,
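The new `_load_best_model` above is carried over nearly verbatim from the PyTorch `Trainer`: it still routes through `safetensors.torch`/`torch.load`, and `weights_only_kwarg` and `has_been_loaded` are only bound on some branches, which matches the file's TODO note that the method is unfinished. For the Flax side, a minimal restore sketch, assuming checkpoints are written with `flax.training.checkpoints` (backed by the `orbax` package the trainer already imports):

# Sketch only: restoring the best checkpoint into a Flax train state,
# assuming it was saved with flax.training.checkpoints.save_checkpoint.
from flax.training import checkpoints

def load_best_flax_state(state, best_model_checkpoint):
    # Returns `state` unchanged if nothing is found at the path.
    return checkpoints.restore_checkpoint(
        ckpt_dir=best_model_checkpoint,
        target=state,
    )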
@@ -595,11 +691,6 @@ def _inner_training_loop(

model = self.state.params

-        # important: at this point:
-        # self.model is the Transformers Model
-        # self.model_wrapped is DDP(Transformers Model), Deepspeed(Transformers Model),
-        # FSDP(Transformers Model), Dynamo Optimized Module(Transformers Model) etc.
-
# Train!
logger.info("***** Running training *****")
logger.info(f" Num examples = {num_examples:,}")
@@ -619,7 +710,7 @@
)
logger.info(f" Total optimization steps = {max_steps:,}")
logger.info(
f" Number of trainable parameters = {get_model_param_count(model, trainable_only=True):,}"
f" Number of trainable parameters = {self.get_model_param_count(model, trainable_only=True):,}"
)

self.state.epoch = 0
@@ -632,7 +723,7 @@
if resume_from_checkpoint is not None and os.path.isfile(
os.path.join(resume_from_checkpoint, TRAINER_STATE_NAME)
):
-            self.state = TrainerState.load(
+            self.state = FlaxTrainerState.load(
os.path.join(resume_from_checkpoint, TRAINER_STATE_NAME)
)
self.compare_trainer_and_checkpoint_args(self.args, self.state)
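`FlaxTrainerState.load` replaces `TrainerState.load` here; the class itself is defined elsewhere in the package. Assuming it mirrors `transformers.TrainerState.load_from_json`, resuming amounts to rehydrating a dataclass from the saved `trainer_state.json`, roughly:

# Hypothetical sketch of the load path; field names follow HF's TrainerState
# and the real FlaxTrainerState may differ.
import dataclasses
import json

@dataclasses.dataclass
class FlaxTrainerStateSketch:
    global_step: int = 0
    epoch: float = 0.0
    best_metric: float | None = None
    best_model_checkpoint: str | None = None

    @classmethod
    def load(cls, json_path):
        with open(json_path, encoding="utf-8") as f:
            return cls(**json.load(f))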
@@ -761,10 +852,6 @@
# last step in epoch but step is always smaller than gradient_accumulation_steps
is_last_step_and_steps_less_than_grad_acc
):
-                # the `or` condition of `is_last_step_and_steps_less_than_grad_acc` is not covered
-                # in accelerate. So, explicitly enable sync gradients to True in that case.
-                if is_last_step_and_steps_less_than_grad_acc:
-                    self.accelerator.gradient_state._set_sync_gradients(True)

# Gradient clipping
if args.max_grad_norm is not None and args.max_grad_norm > 0:
@@ -776,17 +863,14 @@
tr_loss_step, grad = jax.value_and_grad(self.loss)(
self.state.params
)
-                self.state = self.state.apply_gradients(grads=pmean(grad, "batch"))
+                self.state = self.state.apply_gradients(grads=grad)

self.control = self.callback_handler.on_optimizer_step(
args, self.state, self.control
)

-                optimizer_was_run = not self.accelerator.optimizer_step_was_skipped
-                if optimizer_was_run:
-                    # Delay optimizer scheduling until metrics are generated
-                    if not isinstance(self.lr_scheduler, optax._src.base.Schedule):
-                        self.lr_scheduler(self.state.step)
+                if self.lr_scheduler is not None:
+                    self.lr_scheduler(self.state.step)

self.state.global_step += 1
self.state.epoch = (
@@ -805,9 +889,6 @@
)

if self.control.should_epoch_stop or self.control.should_training_stop:
-                # PyTorch/XLA relies on the data loader to insert the mark_step for
-                # each step. Since we are breaking the loop early, we need to manually
-                # insert the mark_step here.
break
if step < 0:
logger.warning(
@@ -904,10 +985,8 @@ def training_step(self, inputs: Dict[str, Union[jax.Array, Any]]) -> jax.Array:
Return:
`torch.Tensor`: The tensor with training loss on this batch.
"""
-        inputs = self._prepare_inputs(inputs)

-        with self.compute_loss_context_manager():
-            loss = self.compute_loss(inputs)
+        loss = self.compute_loss(inputs)

del inputs

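Two of the changes above go together: the update now applies raw grads (dropping `pmean(grad, "batch")`) and `training_step` calls `compute_loss` directly, consistent with the collators' new single-device default. For reference, a minimal sketch of both update flavours on a `flax.training.train_state.TrainState`; the quadratic `loss_fn` and its closure over a batch are illustrative, not the trainer's actual `self.loss`:

from functools import partial
import jax
import jax.numpy as jnp
import optax
from flax.training import train_state

def loss_fn(params, batch):
    pred = batch["x"] @ params["w"]
    return jnp.mean((pred - batch["y"]) ** 2)

@jax.jit
def single_device_step(state, batch):
    # Mirrors the new path: raw grads, no cross-device reduction.
    loss, grads = jax.value_and_grad(loss_fn)(state.params, batch)
    return state.apply_gradients(grads=grads), loss

@partial(jax.pmap, axis_name="batch")
def distributed_step(state, batch):
    # Mirrors the old path: average grads across devices before applying.
    loss, grads = jax.value_and_grad(loss_fn)(state.params, batch)
    grads = jax.lax.pmean(grads, axis_name="batch")
    return state.apply_gradients(grads=grads), loss

state = train_state.TrainState.create(
    apply_fn=None, params={"w": jnp.zeros((4,))}, tx=optax.adam(1e-3)
)
batch = {"x": jnp.ones((8, 4)), "y": jnp.zeros((8,))}
state, loss = single_device_step(state, batch)
# The pmap variant additionally needs flax.jax_utils.replicate(state) and
# batches sharded as in the flaxloader changes above.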
