Adithyare/mamba dpo #374

Draft · wants to merge 5 commits into base: main
1 change: 1 addition & 0 deletions examples/nlp/gpt/conf/gpt_dpo.yaml
@@ -57,6 +57,7 @@ model:
micro_batch_size: 1
global_batch_size: 64
megatron_amp_O2: True
mamba_hybrid: False

dpo:
# This default value ensures there are no numeric differences between trained and reference policies when computing log probs.
2 changes: 1 addition & 1 deletion examples/nlp/gpt/conf/gpt_sft.yaml
@@ -191,7 +191,7 @@ model:
output_original_text: True # needed for the proper metrics support

optim:
name: distributed_fused_adam # Supports distributed optimizer for memory savings. To enable, set to 'distributed_fused_adam'. Needs Apex to be built with specific args to work.
name: fused_adam # Supports distributed optimizer for memory savings. To enable, set to 'distributed_fused_adam'. Needs Apex to be built with specific args to work.
lr: 3e-5
weight_decay: 0.01
betas:
4 changes: 2 additions & 2 deletions examples/nlp/gpt/train_gpt_dpo.py
@@ -21,7 +21,7 @@
from nemo.utils.exp_manager import exp_manager
from nemo_aligner.algorithms.dpo import DPOTrainer, dpo_custom_collate
from nemo_aligner.data.nlp.builders import build_dataloader, build_train_valid_test_dpo_datasets
from nemo_aligner.models.nlp.gpt.megatron_gpt_dpo_model import MegatronGPTDPOModel
from nemo_aligner.models.nlp.gpt.megatron_gpt_dpo_model import MegatronGPTDPOModel, MegatronMambaDPOModel
from nemo_aligner.utils.distributed import Timer
from nemo_aligner.utils.train_script_utils import (
CustomLoggerWrapper,
@@ -53,7 +53,7 @@ def main(cfg) -> None:
logger = CustomLoggerWrapper(trainer.loggers)

ptl_model = load_from_nemo(
MegatronGPTDPOModel,
MegatronMambaDPOModel if cfg.model.mamba_hybrid else MegatronGPTDPOModel,
cfg.model,
trainer,
strict=True,
4 changes: 2 additions & 2 deletions examples/nlp/gpt/train_gpt_sft.py
@@ -27,7 +27,7 @@
from nemo.utils.exp_manager import exp_manager
from nemo_aligner.algorithms.supervised import SupervisedTrainer
from nemo_aligner.data.nlp.builders import build_dataloader, build_sft_dataset
from nemo_aligner.models.nlp.gpt.gpt_sft_model import GPTSFTModel
from nemo_aligner.models.nlp.gpt.gpt_sft_model import GPTSFTModel, MambaSFTModel
from nemo_aligner.utils.distributed import Timer
from nemo_aligner.utils.train_script_utils import (
CustomLoggerWrapper,
@@ -127,7 +127,7 @@ def main(cfg) -> None:
cfg.model.precision = cfg.trainer.precision

ptl_model, updated_cfg = load_from_nemo(
GPTSFTModel,
MambaSFTModel if cfg.model.mamba_hybrid else GPTSFTModel,
cfg,
trainer,
strict=True,
18 changes: 18 additions & 0 deletions nemo_aligner/algorithms/dpo.py
@@ -29,6 +29,15 @@
from nemo_aligner.utils.trainer_utils import check_progress, compute_limit_batches, compute_num_steps_per_epoch
from nemo_aligner.utils.utils import clear_memory

def pad_sequence_to_max(sequences, max_len, padding_value=0):
    # Right-pad a (batch, seq_len) tensor with `padding_value` until dim 1 reaches `max_len`.
    if sequences.size(1) > max_len:
        raise RuntimeError("max_len has to be >= the current sequence length")
    pad_size = max_len - sequences.size(1)
    padding = torch.full((sequences.size(0), pad_size), padding_value, dtype=sequences.dtype, device=sequences.device)
    padded_sequences = torch.cat([sequences, padding], dim=1)
    return padded_sequences

def dpo_custom_collate(batch, eos_id, reset_position_ids=False, reset_attention_mask=False, eod_mask_loss=False):
chosen_tokens = [item["chosen"] for item in batch]
@@ -317,6 +326,15 @@ def augment_dataloader(self, dataloader):
while True:
    try:
        batch = next(iter_dataloader)
        if self.model.cfg.mamba_hybrid:
            # Pad tokens and labels of both responses to one global, 256-aligned length,
            # so every data-parallel rank feeds the model the same sequence length.
            max_seq_len = max([batch['chosen'].size(-1), batch['rejected'].size(-1), batch['chosen_labels'].size(-1), batch['rejected_labels'].size(-1)])
            max_seq_len = torch.tensor(max_seq_len, device=torch.cuda.current_device())
            torch.distributed.all_reduce(max_seq_len, op=torch.distributed.ReduceOp.MAX)
            max_seq_len = ((max_seq_len.item() + 255) // 256) * 256
            batch["chosen"] = pad_sequence_to_max(batch["chosen"], max_seq_len, padding_value=self.model.tokenizer.eos_id)
            batch["chosen_labels"] = pad_sequence_to_max(batch["chosen_labels"], max_seq_len, padding_value=-100)
            batch["rejected"] = pad_sequence_to_max(batch["rejected"], max_seq_len, padding_value=self.model.tokenizer.eos_id)
            batch["rejected_labels"] = pad_sequence_to_max(batch["rejected_labels"], max_seq_len, padding_value=-100)
        logprobs = self.model.get_ref_policy_logprobs(batch).cpu()
        chosen_logps, reject_logps = torch.split(logprobs, len(logprobs) // 2, dim=0)
        batch["ref_policy_log_probs_chosen"] = chosen_logps
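For context, the hybrid-Mamba branch above pads every sequence to one global length: the per-rank maximum over chosen/rejected tokens and labels, reduced with a MAX all-reduce across ranks, then rounded up to a multiple of 256 (presumably for kernel alignment). A minimal sketch of the rounding with made-up numbers:

local_max = 513                                # e.g. the longest of chosen/rejected tokens and labels on this rank
# torch.distributed.all_reduce(tensor, op=ReduceOp.MAX) would take the max across ranks here
padded_len = ((local_max + 255) // 256) * 256  # round up to the next multiple of 256 -> 768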
6 changes: 6 additions & 0 deletions nemo_aligner/models/nlp/gpt/gpt_sft_model.py
@@ -22,6 +22,7 @@
from pytorch_lightning.trainer.trainer import Trainer

from nemo.collections.nlp.models.language_modeling.megatron_gpt_model import MegatronGPTModel
from nemo.collections.nlp.models.language_modeling.megatron_mamba_model import MegatronMambaModel
from nemo.collections.nlp.modules.common.megatron.utils import get_iterator_k_split
from nemo.collections.nlp.modules.common.text_generation_strategy import TextGenerationStrategy
from nemo.collections.nlp.modules.common.text_generation_utils import (
@@ -225,3 +226,8 @@ def finish_inference(self):
self._restore_activation_checkpointing_args()
self._restore_sequence_parallelism_args()
set_train(self)


class MambaSFTModel(MegatronMambaModel, GPTSFTModel):
    def __init__(self, cfg: DictConfig, trainer: Trainer):
        super().__init__(cfg, trainer=trainer)
7 changes: 7 additions & 0 deletions nemo_aligner/models/nlp/gpt/megatron_gpt_dpo_model.py
@@ -16,13 +16,16 @@
from functools import partial

import torch
from megatron.core import parallel_state
from megatron.core.models.mamba import MambaModel
from megatron.core.num_microbatches_calculator import get_num_microbatches
from megatron.core.pipeline_parallel.schedules import get_forward_backward_func
from megatron.core.utils import divide
from omegaconf.dictconfig import DictConfig
from pytorch_lightning.trainer.trainer import Trainer

from nemo.collections.nlp.models.language_modeling.megatron_gpt_model import MegatronGPTModel
from nemo.collections.nlp.models.language_modeling.megatron_mamba_model import MegatronMambaModel
from nemo.collections.nlp.modules.common.megatron.utils import (
average_losses_across_data_parallel_group,
get_iterator_k_split,
@@ -460,3 +463,7 @@ def get_ref_policy_logprobs(self, batch):

# returned on GPU; the trainer needs to move it to CPU
return ref_log_probs

class MegatronMambaDPOModel(MegatronMambaModel, MegatronGPTDPOModel):  # @adithyare inheritance order matters
    def __init__(self, cfg: DictConfig, trainer: Trainer):
        super().__init__(cfg, trainer=trainer)
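The "inheritance order matters" comment refers to Python's method resolution order: with MegatronMambaModel listed first, its attributes and methods shadow those of MegatronGPTDPOModel wherever the two overlap. A self-contained illustration with stand-in classes (not the real NeMo classes):

class MambaLike:
    def forward(self):
        return "mamba"

class DPOLike:
    def forward(self):
        return "gpt-dpo"

class Combined(MambaLike, DPOLike):  # same base ordering as MegatronMambaDPOModel
    pass

assert Combined().forward() == "mamba"  # the left-most base wins under the MRO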
5 changes: 3 additions & 2 deletions nemo_aligner/utils/utils.py
@@ -28,7 +28,7 @@

import torch
from megatron.core.dist_checkpointing.mapping import ShardedObject, ShardedTensorFactory
from megatron.core.num_microbatches_calculator import reconfigure_microbatch_calculator
from megatron.core.num_microbatches_calculator import reconfigure_num_microbatches_calculator as reconfigure_microbatch_calculator
from omegaconf import DictConfig, OmegaConf
from torch.masked import as_masked_tensor

@@ -122,7 +122,8 @@ def load_checkpoint_model_config(restore_path):
return OmegaConf.load(cfg_path)

with tempfile.TemporaryDirectory() as tmpdir:
NLPSaveRestoreConnector._unpack_nemo_file(restore_path, tmpdir, extract_config_only=True)
members = NLPSaveRestoreConnector._filtered_tar_info(restore_path, filter_fn=lambda name: '.yaml' in name)
NLPSaveRestoreConnector._unpack_nemo_file(restore_path, tmpdir, members=members)
cfg = OmegaConf.load(os.path.join(tmpdir, config_name_in_ckpt))

return cfg
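For context on this change: instead of passing extract_config_only, the connector is now asked to unpack only the YAML members of the .nemo archive. A rough standalone equivalent using the tarfile module (illustrative only; the real code goes through NLPSaveRestoreConnector's helpers, and the member-filtering call is the one shown in the diff above):

import tarfile

def extract_yaml_members(nemo_path, out_dir):
    # .nemo checkpoints are tar archives; pull out just the YAML config file(s).
    with tarfile.open(nemo_path, "r") as tar:
        members = [m for m in tar.getmembers() if ".yaml" in m.name]
        tar.extractall(path=out_dir, members=members)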