diff --git a/docs/source/asr/api.rst b/docs/source/asr/api.rst index c99d92c0371a..a35ea49ea2c4 100644 --- a/docs/source/asr/api.rst +++ b/docs/source/asr/api.rst @@ -276,6 +276,21 @@ RNNT Decoding :show-inheritance: :members: +TDT Decoding +~~~~~~~~~~~~~ + +.. autoclass:: nemo.collections.asr.parts.submodules.rnnt_greedy_decoding.GreedyTDTInfer + :show-inheritance: + :members: + +.. autoclass:: nemo.collections.asr.parts.submodules.rnnt_greedy_decoding.GreedyBatchedTDTInfer + :show-inheritance: + :members: + +.. autoclass:: nemo.collections.asr.parts.submodules.tdt_beam_decoding.BeamTDTInfer + :show-inheritance: + :members: + Hypotheses ~~~~~~~~~~ diff --git a/nemo/collections/asr/parts/submodules/rnnt_beam_decoding.py b/nemo/collections/asr/parts/submodules/rnnt_beam_decoding.py index c01f2363db75..e0bd47bb8ce0 100644 --- a/nemo/collections/asr/parts/submodules/rnnt_beam_decoding.py +++ b/nemo/collections/asr/parts/submodules/rnnt_beam_decoding.py @@ -55,6 +55,20 @@ def pack_hypotheses(hypotheses: List[Hypothesis]) -> List[Hypothesis]: + """ + Packs a list of hypotheses into a tensor and prepares decoder states. + + This function takes a list of token sequences (hypotheses) and converts + it into a tensor format. If any decoder states are on the GPU, they + are moved to the CPU. Additionally, the function removes any timesteps + with a value of -1 from the sequences. + + Args: + hypotheses (list): A list of token sequences representing hypotheses. + + Returns: + list: A list of packed hypotheses in tensor format. + """ for idx, hyp in enumerate(hypotheses): # type: rnnt_utils.Hypothesis hyp.y_sequence = torch.tensor(hyp.y_sequence, dtype=torch.long) @@ -69,6 +83,18 @@ def pack_hypotheses(hypotheses: List[Hypothesis]) -> List[Hypothesis]: def _states_to_device(dec_state, device='cpu'): + """ + Transfers decoder states to the specified device. + + This function moves the provided decoder states to the specified device (e.g., 'cpu' or 'cuda'). + + Args: + dec_state (Tensor): The decoder states to be transferred. + device (str): The target device to which the decoder states should be moved. Defaults to 'cpu'. + + Returns: + Tensor: The decoder states on the specified device. + """ if torch.is_tensor(dec_state): dec_state = dec_state.to(device) @@ -106,7 +132,8 @@ class BeamRNNTInfer(Typing): however the time required for the search also grows steadily. `tsd` - time synchronous decoding. Please refer to the paper: - [Alignment-Length Synchronous Decoding for RNN Transducer](https://ieeexplore.ieee.org/document/9053040) + [Alignment-Length Synchronous Decoding for RNN Transducer] + (https://ieeexplore.ieee.org/document/9053040) for details on the algorithm implemented. Time synchronous decoding (TSD) execution time grows by the factor T * max_symmetric_expansions. @@ -114,7 +141,8 @@ class BeamRNNTInfer(Typing): good results. This also requires greater memory to execute. `alsd` - alignment-length synchronous decoding. Please refer to the paper: - [Alignment-Length Synchronous Decoding for RNN Transducer](https://ieeexplore.ieee.org/document/9053040) + [Alignment-Length Synchronous Decoding for RNN Transducer] + (https://ieeexplore.ieee.org/document/9053040) for details on the algorithm implemented. Alignment-length synchronous decoding (ALSD) execution time is faster than TSD, with growth @@ -127,7 +155,8 @@ class BeamRNNTInfer(Typing): For a given decoding accuracy, it is possible to attain faster decoding via ALSD than TSD. `maes` = modified adaptive expansion searcn. 
Please refer to the paper: - [Accelerating RNN Transducer Inference via Adaptive Expansion Search](https://ieeexplore.ieee.org/document/9250505) + [Accelerating RNN Transducer Inference via Adaptive Expansion Search] + (https://ieeexplore.ieee.org/document/9250505) Modified Adaptive Synchronous Decoding (mAES) execution time is adaptive w.r.t the number of expansions (for tokens) required per timestep. The number of expansions can usually @@ -169,10 +198,10 @@ class BeamRNNTInfer(Typing): and affects the speed of inference since large values will perform large beam search in the next step. maes_expansion_gamma: Float pruning threshold used in the prune-by-value step when computing the expansions. - The default (2.3) is selected from the paper. It performs a comparison (max_log_prob - gamma <= log_prob[v]) - where v is all vocabulary indices in the Vocab set and max_log_prob is the "most" likely token to be - predicted. Gamma therefore provides a margin of additional tokens which can be potential candidates for - expansion apart from the "most likely" candidate. + The default (2.3) is selected from the paper. It performs a comparison + (max_log_prob - gamma <= log_prob[v]) where v is all vocabulary indices in the Vocab set and max_log_prob + is the "most" likely token to be predicted. Gamma therefore provides a margin of additional tokens which + can be potential candidates for expansion apart from the "most likely" candidate. Lower values will reduce the number of expansions (by increasing pruning-by-value, thereby improving speed but hurting accuracy). Higher values will increase the number of expansions (by reducing pruning-by-value, thereby reducing speed but potentially improving accuracy). This is a hyper parameter to be experimentally @@ -182,7 +211,7 @@ class BeamRNNTInfer(Typing): preserve_alignments: Bool flag which preserves the history of alignments generated during beam decoding (sample). When set to true, the Hypothesis will contain - the non-null value for `alignments` in it. Here, `alignments` is a List of List of Tensor (of length V + 1). + the non-null value for `alignments` in it. Here, `alignments` is a List of List of Tensor (of length V + 1) The length of the list corresponds to the Acoustic Length (T). Each value in the list (Ti) is a torch.Tensor (U), representing 1 or more targets from a vocabulary. @@ -1456,8 +1485,11 @@ def compute_ngram_score(self, current_lm_state: "kenlm.State", label: int) -> Tu return lm_score, next_state def set_decoding_type(self, decoding_type: str): - - # Please check train_kenlm.py in scripts/asr_language_modeling/ to find out why we need + """ + Sets decoding type. Please check train_kenlm.py in scripts/asr_language_modeling/ to find out why we need + Args: + decoding_type: decoding type + """ # TOKEN_OFFSET for BPE-based models if decoding_type == 'subword': from nemo.collections.asr.parts.submodules.ctc_beam_decoding import DEFAULT_TOKEN_OFFSET @@ -1467,6 +1499,10 @@ def set_decoding_type(self, decoding_type: str): @dataclass class BeamRNNTInferConfig: + """ + Beam RNNT Inference config. 
+ """ + beam_size: int search_type: str = 'default' score_norm: bool = True diff --git a/nemo/collections/asr/parts/submodules/rnnt_decoding.py b/nemo/collections/asr/parts/submodules/rnnt_decoding.py index da280a0c6b3c..d3a63467c485 100644 --- a/nemo/collections/asr/parts/submodules/rnnt_decoding.py +++ b/nemo/collections/asr/parts/submodules/rnnt_decoding.py @@ -23,7 +23,7 @@ import torch from omegaconf import OmegaConf -from nemo.collections.asr.parts.submodules import rnnt_beam_decoding, rnnt_greedy_decoding +from nemo.collections.asr.parts.submodules import rnnt_beam_decoding, rnnt_greedy_decoding, tdt_beam_decoding from nemo.collections.asr.parts.utils.asr_confidence_utils import ConfidenceConfig, ConfidenceMixin from nemo.collections.asr.parts.utils.rnnt_utils import Hypothesis, NBestHypotheses from nemo.collections.common.tokenizers.aggregate_tokenizer import AggregateTokenizer @@ -67,15 +67,15 @@ class AbstractRNNTDecoding(ConfidenceMixin): rnnt_timestamp_type: A str value, which represents the types of timestamps that should be calculated. Can take the following values - "char" for character/subword time stamps, "word" for word level - time stamps, "segment" for segment level time stamps and "all" (default), for character, - word and segment level time stamps. + time stamps, "segment" for segment level time stamps and "all" (default), for character, word and + segment level time stamps. word_seperator: Str token representing the seperator between words. segment_seperators: List containing tokens representing the seperator(s) between segments. - segment_gap_threshold: The threshold (in frames) that caps the gap between two words necessary - for forming the segments. + segment_gap_threshold: The threshold (in frames) that caps the gap between two words necessary for forming + the segments. preserve_frame_confidence: Bool flag which preserves the history of per-frame confidence scores generated during decoding (sample / batched). When set to true, the Hypothesis will contain @@ -106,8 +106,8 @@ class AbstractRNNTDecoding(ConfidenceMixin): from the `token_confidence`. aggregation: Which aggregation type to use for collapsing per-token confidence into per-word confidence. Valid options are `mean`, `min`, `max`, `prod`. - tdt_include_duration: Bool flag indicating that the duration confidence scores are to be calculated - and attached to the regular frame confidence, + tdt_include_duration: Bool flag indicating that the duration confidence scores are to be calculated and + attached to the regular frame confidence, making TDT frame confidence element a pair: (`prediction_confidence`, `duration_confidence`). method_cfg: A dict-like object which contains the method name and settings to compute per-frame confidence scores. @@ -179,23 +179,23 @@ class AbstractRNNTDecoding(ConfidenceMixin): maes_num_steps: Number of adaptive steps to take. From the paper, 2 steps is generally sufficient, and can be reduced to 1 to improve decoding speed while sacrificing some accuracy. int > 0. - maes_prefix_alpha: Maximum prefix length in prefix search. Must be an integer, and is advised to - keep this as 1 in order to reduce expensive beam search cost later. int >= 0. + maes_prefix_alpha: Maximum prefix length in prefix search. Must be an integer, and is advised to keep + this as 1 in order to reduce expensive beam search cost later. int >= 0. maes_expansion_beta: Maximum number of prefix expansions allowed, in addition to the beam size. 
Effectively, the number of hypothesis = beam_size + maes_expansion_beta. Must be an int >= 0, - and affects the speed of inference since large values will perform large beam search in the - next step. + and affects the speed of inference since large values will perform large beam search in the next + step. maes_expansion_gamma: Float pruning threshold used in the prune-by-value step when computing the expansions. The default (2.3) is selected from the paper. It performs a comparison - (max_log_prob - gamma <= log_prob[v]) where v is all vocabulary indices in the Vocab set - and max_log_prob is the "most" likely token to be predicted. Gamma therefore provides a margin - of additional tokens which can be potential candidates for expansion apart from the "most likely" + (max_log_prob - gamma <= log_prob[v]) where v is all vocabulary indices in the Vocab set and + max_log_prob is the "most" likely token to be predicted. Gamma therefore provides a margin of + additional tokens which can be potential candidates for expansion apart from the "most likely" candidate. Lower values will reduce the number of expansions (by increasing pruning-by-value, thereby improving speed but hurting accuracy). Higher values will increase the number of expansions - (by reducing pruning-by-value, thereby reducing speed but potentially improving accuracy). - This is a hyper parameter to be experimentally tuned on a validation set. + (by reducing pruning-by-value, thereby reducing speed but potentially improving accuracy). This is + a hyper parameter to be experimentally tuned on a validation set. softmax_temperature: Scales the logits of the joint prior to computing log_softmax. @@ -234,8 +234,10 @@ def __init__(self, decoding_cfg, decoder, joint, blank_id: int, supported_punctu raise ValueError("blank_id must equal len(non_blank_vocabs) for TDT models") if self.big_blank_durations is not None and self.big_blank_durations != []: raise ValueError("duration and big_blank_durations can't both be not None") - if self.cfg.strategy not in ['greedy', 'greedy_batch']: - raise ValueError("currently only greedy and greedy_batch inference is supported for TDT models") + if self.cfg.strategy not in ['greedy', 'greedy_batch', 'beam', 'maes']: + raise ValueError( + "currently only greedy, greedy_batch, beam and maes inference is supported for TDT models" + ) if ( self.big_blank_durations is not None and self.big_blank_durations != [] @@ -386,20 +388,32 @@ def __init__(self, decoding_cfg, decoder, joint, blank_id: int, supported_punctu ) elif self.cfg.strategy == 'beam': - - self.decoding = rnnt_beam_decoding.BeamRNNTInfer( - decoder_model=decoder, - joint_model=joint, - beam_size=self.cfg.beam.beam_size, - return_best_hypothesis=decoding_cfg.beam.get('return_best_hypothesis', True), - search_type='default', - score_norm=self.cfg.beam.get('score_norm', True), - softmax_temperature=self.cfg.beam.get('softmax_temperature', 1.0), - preserve_alignments=self.preserve_alignments, - ) + if self.big_blank_durations is None or self.big_blank_durations == []: + if not self._is_tdt: + self.decoding = rnnt_beam_decoding.BeamRNNTInfer( + decoder_model=decoder, + joint_model=joint, + beam_size=self.cfg.beam.beam_size, + return_best_hypothesis=decoding_cfg.beam.get('return_best_hypothesis', True), + search_type='default', + score_norm=self.cfg.beam.get('score_norm', True), + softmax_temperature=self.cfg.beam.get('softmax_temperature', 1.0), + preserve_alignments=self.preserve_alignments, + ) + else: + self.decoding = 
tdt_beam_decoding.BeamTDTInfer( + decoder_model=decoder, + joint_model=joint, + durations=self.durations, + beam_size=self.cfg.beam.beam_size, + return_best_hypothesis=decoding_cfg.beam.get('return_best_hypothesis', True), + search_type='default', + score_norm=self.cfg.beam.get('score_norm', True), + softmax_temperature=self.cfg.beam.get('softmax_temperature', 1.0), + preserve_alignments=self.preserve_alignments, + ) elif self.cfg.strategy == 'tsd': - self.decoding = rnnt_beam_decoding.BeamRNNTInfer( decoder_model=decoder, joint_model=joint, @@ -413,7 +427,6 @@ def __init__(self, decoding_cfg, decoder, joint, blank_id: int, supported_punctu ) elif self.cfg.strategy == 'alsd': - self.decoding = rnnt_beam_decoding.BeamRNNTInfer( decoder_model=decoder, joint_model=joint, @@ -427,26 +440,44 @@ def __init__(self, decoding_cfg, decoder, joint, blank_id: int, supported_punctu ) elif self.cfg.strategy == 'maes': - - self.decoding = rnnt_beam_decoding.BeamRNNTInfer( - decoder_model=decoder, - joint_model=joint, - beam_size=self.cfg.beam.beam_size, - return_best_hypothesis=decoding_cfg.beam.get('return_best_hypothesis', True), - search_type='maes', - score_norm=self.cfg.beam.get('score_norm', True), - maes_num_steps=self.cfg.beam.get('maes_num_steps', 2), - maes_prefix_alpha=self.cfg.beam.get('maes_prefix_alpha', 1), - maes_expansion_gamma=self.cfg.beam.get('maes_expansion_gamma', 2.3), - maes_expansion_beta=self.cfg.beam.get('maes_expansion_beta', 2.0), - softmax_temperature=self.cfg.beam.get('softmax_temperature', 1.0), - preserve_alignments=self.preserve_alignments, - ngram_lm_model=self.cfg.beam.get('ngram_lm_model', None), - ngram_lm_alpha=self.cfg.beam.get('ngram_lm_alpha', 0.0), - hat_subtract_ilm=self.cfg.beam.get('hat_subtract_ilm', False), - hat_ilm_weight=self.cfg.beam.get('hat_ilm_weight', 0.0), - ) - + if self.big_blank_durations is None or self.big_blank_durations == []: + if not self._is_tdt: + self.decoding = rnnt_beam_decoding.BeamRNNTInfer( + decoder_model=decoder, + joint_model=joint, + beam_size=self.cfg.beam.beam_size, + return_best_hypothesis=decoding_cfg.beam.get('return_best_hypothesis', True), + search_type='maes', + score_norm=self.cfg.beam.get('score_norm', True), + maes_num_steps=self.cfg.beam.get('maes_num_steps', 2), + maes_prefix_alpha=self.cfg.beam.get('maes_prefix_alpha', 1), + maes_expansion_gamma=self.cfg.beam.get('maes_expansion_gamma', 2.3), + maes_expansion_beta=self.cfg.beam.get('maes_expansion_beta', 2.0), + softmax_temperature=self.cfg.beam.get('softmax_temperature', 1.0), + preserve_alignments=self.preserve_alignments, + ngram_lm_model=self.cfg.beam.get('ngram_lm_model', None), + ngram_lm_alpha=self.cfg.beam.get('ngram_lm_alpha', 0.0), + hat_subtract_ilm=self.cfg.beam.get('hat_subtract_ilm', False), + hat_ilm_weight=self.cfg.beam.get('hat_ilm_weight', 0.0), + ) + else: + self.decoding = tdt_beam_decoding.BeamTDTInfer( + decoder_model=decoder, + joint_model=joint, + durations=self.durations, + beam_size=self.cfg.beam.beam_size, + return_best_hypothesis=decoding_cfg.beam.get('return_best_hypothesis', True), + search_type='maes', + score_norm=self.cfg.beam.get('score_norm', True), + maes_num_steps=self.cfg.beam.get('maes_num_steps', 2), + maes_prefix_alpha=self.cfg.beam.get('maes_prefix_alpha', 1), + maes_expansion_gamma=self.cfg.beam.get('maes_expansion_gamma', 2.3), + maes_expansion_beta=self.cfg.beam.get('maes_expansion_beta', 2.0), + softmax_temperature=self.cfg.beam.get('softmax_temperature', 1.0), + preserve_alignments=self.preserve_alignments, + 
ngram_lm_model=self.cfg.beam.get('ngram_lm_model', None), + ngram_lm_alpha=self.cfg.beam.get('ngram_lm_alpha', 0.3), + ) else: raise ValueError( @@ -728,6 +759,15 @@ def decode_ids_to_langs(self, tokens: List[int]) -> List[str]: raise NotImplementedError() def update_joint_fused_batch_size(self): + """ " + Updates the fused batch size for the joint module if applicable. + + If `joint_fused_batch_size` is set, verifies that the joint module has + the required `set_fused_batch_size` and `set_fuse_loss_wer` functions. + If present, updates the batch size; otherwise, logs a warning. + + If `joint_fused_batch_size` is <= 0, disables fused batch processing. + """ if self.joint_fused_batch_size is None: # do nothing and let the Joint itself handle setting up of the fused batch return @@ -754,6 +794,21 @@ def update_joint_fused_batch_size(self): self.decoding.joint.set_fuse_loss_wer(False) def compute_rnnt_timestamps(self, hypothesis: Hypothesis, timestamp_type: str = "all"): + """ + Computes character, word, and segment timestamps for an RNN-T hypothesis. + + This function generates timestamps for characters, words, and segments within + a hypothesis sequence. The type of timestamps computed depends on `timestamp_type`, + which can be 'char', 'word', 'segment', or 'all'. + + Args: + hypothesis (Hypothesis): Hypothesis. + timestamp_type (str): Type of timestamps to compute. Options are 'char', 'word', 'segment', or 'all'. + Defaults to 'all'. + + Returns: + Hypothesis: The updated hypothesis with computed timestamps for characters, words, and/or segments. + """ assert timestamp_type in ['char', 'word', 'segment', 'all'] # Unpack the temporary storage @@ -890,7 +945,7 @@ def _compute_offsets( # Construct the start and end indices brackets end_indices = np.asarray(token_repetitions).cumsum() - start_indices = np.concatenate(([int(start_index)], end_indices[:-1])) + start_indices = np.concatenate(([start_index], end_indices[:-1])) # Process the TxU dangling alignment tensor, containing pairs of (logits, label) alignment_labels = [al_logits_labels for al_logits_labels in hypothesis.text[1]] @@ -953,8 +1008,8 @@ def _refine_timestamps_tdt( # Check if token is a punctuation mark # If so, set its start and end offset as start and end of the previous token - # This is done because there was observed a behaviour, when punctuation marks are predicted long - # after preceding token (i.e. after silence) + # This is done because there was observed a behaviour, when punctuation marks are + # predicted long after preceding token (i.e. after silence) if offset['char'][0] in supported_punctuation and i > 0: encoded_char_offsets[i]['start_offset'] = offset['start_offset'] = char_offsets[i - 1]['end_offset'] encoded_char_offsets[i]['end_offset'] = offset['end_offset'] = offset['start_offset'] @@ -1114,7 +1169,8 @@ def _get_segment_offsets( offsets: A list of dictionaries, each containing "word", "start_offset" and "end_offset". segments_delimiter_tokens: List containing tokens representing the seperator(s) between segments. supported_punctuation: Set containing punctuation marks in the vocabulary. - segment_gap_threshold: Number of frames between 2 consecutive words necessary to form segments out of plain text. + segment_gap_threshold: Number of frames between 2 consecutive words necessary to form segments out of plain + text. Returns: A list of dictionaries containing the segment offsets. Each item contains "segment", "start_offset" and "end_offset". 
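The dispatch above means TDT models can now request beam or mAES search purely through the decoding config. Below is a minimal sketch of such a config, assuming a loaded NeMo TDT model that exposes `change_decoding_strategy` (the commented calls are illustrative, not verified against any specific checkpoint); the keys mirror the `cfg.beam.*` values read when constructing `BeamTDTInfer`.

```python
# Minimal sketch: enabling the new TDT beam / mAES decoding path added above.
# Config keys follow the cfg.beam.* values read by AbstractRNNTDecoding when it
# builds tdt_beam_decoding.BeamTDTInfer; the model/checkpoint names are hypothetical.
from omegaconf import OmegaConf

decoding_cfg = OmegaConf.create(
    {
        "strategy": "maes",          # now accepted for TDT models ('beam' also works)
        "beam": {
            "beam_size": 4,
            "return_best_hypothesis": True,
            "score_norm": True,
            "maes_num_steps": 2,
            "maes_prefix_alpha": 1,
            "maes_expansion_beta": 2,
            "maes_expansion_gamma": 2.3,
            "ngram_lm_model": None,   # optional KenLM path for shallow fusion
            "ngram_lm_alpha": 0.3,
        },
    }
)

# model = nemo_asr.models.ASRModel.restore_from("tdt_model.nemo")  # hypothetical checkpoint
# model.change_decoding_strategy(decoding_cfg)
# hyps = model.transcribe(["audio.wav"])
```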
@@ -1242,9 +1298,10 @@ class RNNTDecoding(AbstractRNNTDecoding): exclude_blank: Bool flag indicating that blank token confidence scores are to be excluded from the `token_confidence`. aggregation: Which aggregation type to use for collapsing per-token confidence into per-word - confidence. Valid options are `mean`, `min`, `max`, `prod`. - tdt_include_duration: Bool flag indicating that the duration confidence scores are to be calculated - and attached to the regular frame confidence, + confidence. + Valid options are `mean`, `min`, `max`, `prod`. + tdt_include_duration: Bool flag indicating that the duration confidence scores are to be calculated and + attached to the regular frame confidence, making TDT frame confidence element a pair: (`prediction_confidence`, `duration_confidence`). method_cfg: A dict-like object which contains the method name and settings to compute per-frame confidence scores. @@ -1331,7 +1388,7 @@ class RNNTDecoding(AbstractRNNTDecoding): and can be reduced to 1 to improve decoding speed while sacrificing some accuracy. int > 0. maes_prefix_alpha: Maximum prefix length in prefix search. Must be an integer, and is advised to - keep this as 1 in order to reduce expensive beam search cost later. int >= 0. + keep this as 1 in order to reduce expensive beam search cost later. int >= 0. maes_expansion_beta: Maximum number of prefix expansions allowed, in addition to the beam size. Effectively, the number of hypothesis = beam_size + maes_expansion_beta. Must be an int >= 0, @@ -1339,8 +1396,7 @@ class RNNTDecoding(AbstractRNNTDecoding): next step. maes_expansion_gamma: Float pruning threshold used in the prune-by-value step when computing the - expansions. - The default (2.3) is selected from the paper. It performs a comparison + expansions. The default (2.3) is selected from the paper. It performs a comparison (max_log_prob - gamma <= log_prob[v]) where v is all vocabulary indices in the Vocab set and max_log_prob is the "most" likely token to be predicted. Gamma therefore provides a margin of additional tokens which can be potential candidates for expansion apart from the "most likely" @@ -1382,7 +1438,9 @@ def __init__( supported_punctuation=supported_punctuation, ) - if isinstance(self.decoding, rnnt_beam_decoding.BeamRNNTInfer): + if isinstance(self.decoding, rnnt_beam_decoding.BeamRNNTInfer) or isinstance( + self.decoding, tdt_beam_decoding.BeamTDTInfer + ): self.decoding.set_decoding_type('char') def _aggregate_token_confidence(self, hypothesis: Hypothesis) -> List[float]: @@ -1498,8 +1556,8 @@ class RNNTBPEDecoding(AbstractRNNTDecoding): segment_seperators: List containing tokens representing the seperator(s) between segments. - segment_gap_threshold: The threshold (in frames) that caps the gap between two words necessary for - forming the segments. + segment_gap_threshold: The threshold (in frames) that caps the gap between two words necessary for forming + the segments. preserve_frame_confidence: Bool flag which preserves the history of per-frame confidence scores generated during decoding (sample / batched). When set to true, the Hypothesis will contain @@ -1530,8 +1588,8 @@ class RNNTBPEDecoding(AbstractRNNTDecoding): from the `token_confidence`. aggregation: Which aggregation type to use for collapsing per-token confidence into per-word confidence. Valid options are `mean`, `min`, `max`, `prod`. 
- tdt_include_duration: Bool flag indicating that the duration confidence scores are to be - calculated and attached to the regular frame confidence, + tdt_include_duration: Bool flag indicating that the duration confidence scores are to be calculated and + attached to the regular frame confidence, making TDT frame confidence element a pair: (`prediction_confidence`, `duration_confidence`). method_cfg: A dict-like object which contains the method name and settings to compute per-frame confidence scores. @@ -1602,7 +1660,7 @@ class RNNTBPEDecoding(AbstractRNNTDecoding): at increased cost to execution time. alsd_max_target_len: optional int or float, determines the potential maximum target sequence - length. If an integer is provided, it can decode sequences of that particular maximum length. + length.If an integer is provided, it can decode sequences of that particular maximum length. If a float is provided, it can decode sequences of int(alsd_max_target_len * seq_len), where seq_len is the length of the acoustic model output (T). @@ -1622,16 +1680,15 @@ class RNNTBPEDecoding(AbstractRNNTDecoding): and affects the speed of inference since large values will perform large beam search in the next step. - maes_expansion_gamma: Float pruning threshold used in the prune-by-value step when - computing the expansions. The default (2.3) is selected from the paper. It performs a - comparison (max_log_prob - gamma <= log_prob[v]) where v is all vocabulary indices in the - Vocab set and max_log_prob is the "most" likely token to be predicted. Gamma therefore - provides a margin of additional tokens which can be potential candidates for expansion - apart from the "most likely" candidate. Lower values will reduce the number of expansions - (by increasing pruning-by-value, thereby improving speed but hurting accuracy). Higher - values will increase the number of expansions (by reducing pruning-by-value, thereby - reducing speed but potentially improving accuracy). This is a hyper parameter to be - experimentally tuned on a validation set. + maes_expansion_gamma: Float pruning threshold used in the prune-by-value step when computing the + expansions. The default (2.3) is selected from the paper. It performs a comparison + (max_log_prob - gamma <= log_prob[v]) where v is all vocabulary indices in the Vocab set and + max_log_prob is the "most" likely token to be predicted. Gamma therefore provides a margin of + additional tokens which can be potential candidates for expansion apart from the "most likely" + candidate. Lower values will reduce the number of expansions (by increasing pruning-by-value, + thereby improving speed but hurting accuracy). Higher values will increase the number of + expansions (by reducing pruning-by-value, thereby reducing speed but potentially improving + accuracy). This is a hyper parameter to be experimentally tuned on a validation set. softmax_temperature: Scales the logits of the joint prior to computing log_softmax. 
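The `softmax_temperature` and `maes_expansion_gamma` descriptions above reduce to two simple tensor operations. The following is a minimal, self-contained sketch with made-up sizes, not the library implementation itself:

```python
# Sketch of temperature scaling and the mAES prune-by-value step described above.
import torch

logits = torch.randn(29)               # hypothetical joint output for one (t, u) step
softmax_temperature = 1.5
gamma = 2.3                            # maes_expansion_gamma

# softmax_temperature scales the logits before log_softmax
log_probs = torch.log_softmax(logits / softmax_temperature, dim=-1)

# prune-by-value: keep tokens v satisfying max_log_prob - gamma <= log_prob[v]
max_log_prob = log_probs.max()
expansion_candidates = torch.nonzero(log_probs >= max_log_prob - gamma).squeeze(-1)
print(expansion_candidates)            # token ids that survive pruning
```

Smaller gamma keeps fewer candidates (faster, less accurate); larger gamma keeps more (slower, potentially more accurate).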
@@ -1658,7 +1715,9 @@ def __init__(self, decoding_cfg, decoder, joint, tokenizer: TokenizerSpec): supported_punctuation=supported_punctuation, ) - if isinstance(self.decoding, rnnt_beam_decoding.BeamRNNTInfer): + if isinstance(self.decoding, rnnt_beam_decoding.BeamRNNTInfer) or isinstance( + self.decoding, tdt_beam_decoding.BeamTDTInfer + ): self.decoding.set_decoding_type('subword') def _aggregate_token_confidence(self, hypothesis: Hypothesis) -> List[float]: @@ -1759,8 +1818,8 @@ def decode_hypothesis(self, hypotheses_list: List[Hypothesis]) -> List[Union[Hyp hypotheses[ind].langs_chars = self.decode_ids_to_langs(prediction) else: logging.warning( - "Ignoring request for lang output in hypotheses since the model does not use an aggregate\ - tokenizer" + "Ignoring request for lang output in hypotheses since the model does not use an aggregate \ + tokenizer" ) return hypotheses @@ -1768,6 +1827,10 @@ def decode_hypothesis(self, hypotheses_list: List[Hypothesis]) -> List[Union[Hyp @dataclass class RNNTDecodingConfig: + """ + RNNT Decoding config + """ + model_type: str = "rnnt" # one of "rnnt", "multiblank" or "tdt" strategy: str = "greedy_batch" @@ -1825,4 +1888,8 @@ class RNNTDecodingConfig: @dataclass class RNNTBPEDecodingConfig(RNNTDecodingConfig): + """ + RNNT BPE Decoding Config + """ + pass diff --git a/nemo/collections/asr/parts/submodules/rnnt_greedy_decoding.py b/nemo/collections/asr/parts/submodules/rnnt_greedy_decoding.py index f9cf368fe405..bd169d0d224e 100644 --- a/nemo/collections/asr/parts/submodules/rnnt_greedy_decoding.py +++ b/nemo/collections/asr/parts/submodules/rnnt_greedy_decoding.py @@ -49,7 +49,20 @@ def pack_hypotheses( hypotheses: List[rnnt_utils.Hypothesis], logitlen: torch.Tensor, ) -> List[rnnt_utils.Hypothesis]: + """ + Packs a list of hypotheses into a tensor and prepares decoder states. + + This function takes a list of token sequences (hypotheses) and converts + it into a tensor format. If any decoder states are on the GPU, they + are moved to the CPU. Additionally, the function removes any timesteps + with a value of -1 from the sequences. + + Args: + hypotheses (list): A list of token sequences representing hypotheses. + Returns: + list: A list of packed hypotheses in tensor format. + """ if hasattr(logitlen, 'cpu'): logitlen_cpu = logitlen.to('cpu') else: @@ -578,7 +591,8 @@ class GreedyBatchedRNNTInfer(_GreedyRNNTInfer, WithOptionalCudaGraphs): (evaluating Joint multiple times in inner loop); It uses a minimal possible amount of calls to prediction network (with maximum possible batch size), which makes it especially useful for scaling the prediction network. 
- use_cuda_graph_decoder: if CUDA graphs should be enabled for decoding (currently recommended only for inference) + use_cuda_graph_decoder: if CUDA graphs should be enabled for decoding + (currently recommended only for inference) """ def __init__( @@ -1169,6 +1183,10 @@ def _greedy_decode_masked( class ExportedModelGreedyBatchedRNNTInfer: + """ + Exported Model Greedy Batched RNNT Infer class + """ + def __init__(self, encoder_model: str, decoder_joint_model: str, max_symbols_per_step: Optional[int] = None): self.encoder_model_path = encoder_model self.decoder_joint_model_path = decoder_joint_model @@ -1344,9 +1362,25 @@ def _setup_blank_index(self): raise NotImplementedError() def run_encoder(self, audio_signal, length): + """ + Runs encoder network: + + Args: + audio_signal: audio signal + length: audio length + """ raise NotImplementedError() def run_decoder_joint(self, enc_logits, targets, target_length, *states): + """ + Runs decoder joint networks. + + Args: + enc_logits: encoder logits + targets: targets + target_length: target length + states: states + """ raise NotImplementedError() def _get_initial_states(self, batchsize): @@ -1354,6 +1388,10 @@ def _get_initial_states(self, batchsize): class ONNXGreedyBatchedRNNTInfer(ExportedModelGreedyBatchedRNNTInfer): + """ + ONNX Greedy Batched RNNT Infer class + """ + def __init__(self, encoder_model: str, decoder_joint_model: str, max_symbols_per_step: Optional[int] = 10): super().__init__( encoder_model=encoder_model, @@ -1433,7 +1471,8 @@ def _setup_blank_index(self): self._blank_index = log_probs.shape[-1] - 1 # last token of vocab size is blank token logging.info( - f"Enc-Dec-Joint step was evaluated, blank token id = {self._blank_index}; vocab size = {log_probs.shape[-1]}" + f"Enc-Dec-Joint step was evaluated, \ + blank token id = {self._blank_index}; vocab size = {log_probs.shape[-1]}" ) def run_encoder(self, audio_signal, length): @@ -1512,6 +1551,10 @@ def _get_initial_states(self, batchsize): class TorchscriptGreedyBatchedRNNTInfer(ExportedModelGreedyBatchedRNNTInfer): + """ + Torchscript Greedy Batched RNNT Infer + """ + def __init__( self, encoder_model: str, @@ -2336,6 +2379,8 @@ def _greedy_decode_masked( @dataclass class GreedyRNNTInferConfig: + """Greedy RNNT Infer Config""" + max_symbols_per_step: Optional[int] = 10 preserve_alignments: bool = False preserve_frame_confidence: bool = False @@ -2354,6 +2399,8 @@ def __post_init__(self): @dataclass class GreedyBatchedRNNTInferConfig: + """Greedy Batched RNNT Infer Config""" + max_symbols_per_step: Optional[int] = 10 preserve_alignments: bool = False preserve_frame_confidence: bool = False @@ -2708,7 +2755,8 @@ class GreedyBatchedTDTInfer(_GreedyRNNTInfer, WithOptionalCudaGraphs): - 'lin' for using the linear mapping. - 'exp' for using exponential mapping with linear shift. - use_cuda_graph_decoder: if CUDA graphs should be enabled for decoding (currently recommended only for inference) + use_cuda_graph_decoder: if CUDA graphs should be enabled for decoding + (currently recommended only for inference) """ def __init__( diff --git a/nemo/collections/asr/parts/submodules/tdt_beam_decoding.py b/nemo/collections/asr/parts/submodules/tdt_beam_decoding.py new file mode 100644 index 000000000000..908fc1c13d19 --- /dev/null +++ b/nemo/collections/asr/parts/submodules/tdt_beam_decoding.py @@ -0,0 +1,800 @@ +# Copyright (c) 2024, NVIDIA CORPORATION. All rights reserved. 
+# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +# Copyright 2017 Johns Hopkins University (Shinji Watanabe) +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +from typing import List, Optional, Tuple + +import numpy as np +import torch +from tqdm import tqdm + +from nemo.collections.asr.modules import rnnt_abstract +from nemo.collections.asr.parts.submodules.rnnt_beam_decoding import pack_hypotheses +from nemo.collections.asr.parts.utils.rnnt_utils import Hypothesis, NBestHypotheses, is_prefix +from nemo.core.classes import Typing, typecheck +from nemo.core.neural_types import AcousticEncodedRepresentation, HypothesisType, LengthsType, NeuralType +from nemo.utils import logging + +try: + import kenlm + + KENLM_AVAILABLE = True +except (ImportError, ModuleNotFoundError): + KENLM_AVAILABLE = False + + +class BeamTDTInfer(Typing): + """ + Beam search implementation for Token-andDuration Transducer (TDT) models. + + Sequence level beam decoding or batched-beam decoding, performed auto-repressively + depending on the search type chosen. + + Args: + decoder_model: rnnt_utils.AbstractRNNTDecoder implementation. + joint_model: rnnt_utils.AbstractRNNTJoint implementation. + durations: list of duration values from TDT model. + + beam_size: number of beams for beam search. Must be a positive integer >= 1. + If beam size is 1, defaults to stateful greedy search. + For accurate greedy results, please use GreedyRNNTInfer or GreedyBatchedRNNTInfer. + + search_type: str representing the type of beam search to perform. + Must be one of ['beam', 'maes']. + + Algorithm used: + + `default` - basic beam search strategy. Larger beams generally result in better decoding, + however the time required for the search also grows steadily. + + `maes` = modified adaptive expansion search. Please refer to the paper: + [Accelerating RNN Transducer Inference via Adaptive Expansion Search] + (https://ieeexplore.ieee.org/document/9250505) + + Modified Adaptive Synchronous Decoding (mAES) execution time is adaptive w.r.t the + number of expansions (for tokens) required per timestep. The number of expansions can usually + be constrained to 1 or 2, and in most cases 2 is sufficient. + + This beam search technique can possibly obtain superior WER while sacrificing some evaluation time. + + score_norm: bool, whether to normalize the scores of the log probabilities. + + return_best_hypothesis: bool, decides whether to return a single hypothesis (the best out of N), + or return all N hypothesis (sorted with best score first). 
The container class changes based + this flag - + When set to True (default), returns a single Hypothesis. + When set to False, returns a NBestHypotheses container, which contains a list of Hypothesis. + + # The following arguments are specific to the chosen `search_type` + + # mAES flags + maes_num_steps: Number of adaptive steps to take. From the paper, 2 steps is generally sufficient. int > 1. + + maes_prefix_alpha: Maximum prefix length in prefix search. Must be an integer, and is advised to keep this as 1 + in order to reduce expensive beam search cost later. int >= 0. + + maes_expansion_beta: Maximum number of prefix expansions allowed, in addition to the beam size. + Effectively, the number of hypothesis = beam_size + maes_expansion_beta. Must be an int >= 0, + and affects the speed of inference since large values will perform large beam search in the next step. + + maes_expansion_gamma: Float pruning threshold used in the prune-by-value step when computing the expansions. + The default (2.3) is selected from the paper. It performs a comparison + (max_log_prob - gamma <= log_prob[v]) where v is all vocabulary indices in the Vocab set and max_log_prob + is the "most" likely token to be predicted. Gamma therefore provides a margin of additional tokens which + can be potential candidates for expansion apart from the "most likely" candidate. + Lower values will reduce the number of expansions (by increasing pruning-by-value, thereby improving speed + but hurting accuracy). Higher values will increase the number of expansions (by reducing pruning-by-value, + thereby reducing speed but potentially improving accuracy). This is a hyper parameter to be experimentally + tuned on a validation set. + + softmax_temperature: Scales the logits of the joint prior to computing log_softmax. + + preserve_alignments: Bool flag which preserves the history of alignments generated during + beam decoding (sample). When set to true, the Hypothesis will contain + the non-null value for `alignments` in it. Here, `alignments` is a List of List of Tensor (of length V + 1) + + The length of the list corresponds to the Acoustic Length (T). + Each value in the list (Ti) is a torch.Tensor (U), representing 1 or more targets from a vocabulary. + U is the number of target tokens for the current timestep Ti. + + NOTE: `preserve_alignments` is an invalid argument for any `search_type` + other than basic beam search. + + ngram_lm_model: str + The path to the N-gram LM. + ngram_lm_alpha: float + Alpha weight of N-gram LM. 
+ """ + + @property + def input_types(self): + """Returns definitions of module input ports.""" + return { + "encoder_output": NeuralType(('B', 'D', 'T'), AcousticEncodedRepresentation()), + "encoded_lengths": NeuralType(tuple('B'), LengthsType()), + "partial_hypotheses": [NeuralType(elements_type=HypothesisType(), optional=True)], # must always be last + } + + @property + def output_types(self): + """Returns definitions of module output ports.""" + return {"predictions": [NeuralType(elements_type=HypothesisType())]} + + def __init__( + self, + decoder_model: rnnt_abstract.AbstractRNNTDecoder, + joint_model: rnnt_abstract.AbstractRNNTJoint, + durations: list, + beam_size: int, + search_type: str = 'default', + score_norm: bool = True, + return_best_hypothesis: bool = True, + maes_num_steps: int = 2, + maes_prefix_alpha: int = 1, + maes_expansion_gamma: float = 2.3, + maes_expansion_beta: int = 2, + softmax_temperature: float = 1.0, + preserve_alignments: bool = False, + ngram_lm_model: Optional[str] = None, + ngram_lm_alpha: float = 0.3, + ): + self.joint = joint_model + self.decoder = decoder_model + self.durations = durations + + self.token_offset = 0 + self.search_type = search_type + self.blank = decoder_model.blank_idx + self.vocab_size = decoder_model.vocab_size + self.return_best_hypothesis = return_best_hypothesis + + self.beam_size = beam_size + self.score_norm = score_norm + self.max_candidates = beam_size + self.softmax_temperature = softmax_temperature + self.preserve_alignments = preserve_alignments + + if preserve_alignments: + raise ValueError("Alignment preservation has not been implemented.") + if beam_size < 1: + raise ValueError("Beam search size cannot be less than 1!") + + if self.preserve_alignments: + raise NotImplementedError("Preserving alignments is not implemented.") + + if search_type == "default": + if self.beam_size == 1: + logging.info( + """If beam size is 1, defaults to stateful greedy search. + For accurate greedy results, please use GreedyTDTInfer or GreedyBatchedTDTInfer.""" + ) + self.search_algorithm = self.default_beam_search + elif search_type == "tsd": + raise NotImplementedError("`tsd` (Time Synchronous Decoding) has not been implemented.") + elif search_type == "alsd": + raise NotImplementedError("`alsd` (Alignment Length Synchronous Decoding) has not been implemented.") + elif search_type == "nsc": + raise NotImplementedError("`nsc` (Constrained Beam Search) has not been implemented.") + elif search_type == "maes": + self.search_algorithm = self.modified_adaptive_expansion_search + else: + raise NotImplementedError( + f"The search type ({search_type}) supplied is not supported!\n" f"Please use one of : (default, maes)" + ) + + if self.search_type == 'maes': + self.maes_num_steps = int(maes_num_steps) + self.maes_prefix_alpha = int(maes_prefix_alpha) + self.maes_expansion_beta = int(maes_expansion_beta) + self.maes_expansion_gamma = float(maes_expansion_gamma) + + self.max_candidates += maes_expansion_beta + + if self.maes_prefix_alpha < 0: + raise ValueError("`maes_prefix_alpha` must be a positive integer.") + + if self.vocab_size < beam_size + maes_expansion_beta: + raise ValueError( + f"beam_size ({beam_size}) + expansion_beta ({maes_expansion_beta}) " + f"should be smaller or equal to vocabulary size ({self.vocab_size})." 
+ ) + + if self.maes_num_steps < 1: + raise ValueError("`maes_num_steps` must be greater than 0.") + + try: + self.zero_duration_idx = self.durations.index(0) + except ValueError: + self.zero_duration_idx = None + self.min_non_zero_duration_idx = int( + np.argmin(np.ma.masked_where(np.array(self.durations) == 0, self.durations)) + ) + + if ngram_lm_model: + if search_type != "maes": + raise ValueError("For decoding with language model `maes` decoding strategy must be chosen.") + + if KENLM_AVAILABLE: + self.ngram_lm = kenlm.Model(ngram_lm_model) + self.ngram_lm_alpha = ngram_lm_alpha + else: + raise ImportError( + "KenLM package (https://github.com/kpu/kenlm) is not installed. " "Use ngram_lm_model=None." + ) + else: + self.ngram_lm = None + + @typecheck() + def __call__( + self, + encoder_output: torch.Tensor, + encoded_lengths: torch.Tensor, + partial_hypotheses: tuple[list[Hypothesis | NBestHypotheses],] = None, + ) -> tuple[list[Hypothesis | NBestHypotheses],]: + """Perform general beam search. + + Args: + encoder_output: encoder outputs (batch, features, timesteps). + encoded_lengths: lengths of the encoder outputs. + + Returns: + Either a list containing a single Hypothesis (when `return_best_hypothesis=True`, + otherwise a list containing a single NBestHypotheses, which itself contains a list of + Hypothesis. This list is sorted such that the best hypothesis is the first element. + """ + # Preserve decoder and joint training state + decoder_training_state = self.decoder.training + joint_training_state = self.joint.training + + with torch.inference_mode(): + # Apply optional preprocessing + encoder_output = encoder_output.transpose(1, 2) # (B, T, D) + + self.decoder.eval() + self.joint.eval() + + hypotheses = [] + with tqdm( + range(encoder_output.size(0)), + desc='Beam search progress:', + total=encoder_output.size(0), + unit='sample', + ) as idx_gen: + + _p = next(self.joint.parameters()) + dtype = _p.dtype + + # Decode every sample in the batch independently. + for batch_idx in idx_gen: + inseq = encoder_output[batch_idx : batch_idx + 1, : encoded_lengths[batch_idx], :] # [1, T, D] + logitlen = encoded_lengths[batch_idx] + + if inseq.dtype != dtype: + inseq = inseq.to(dtype=dtype) + + # Extract partial hypothesis if exists + partial_hypothesis = partial_hypotheses[batch_idx] if partial_hypotheses is not None else None + + # Execute the specific search strategy + nbest_hyps = self.search_algorithm( + inseq, logitlen, partial_hypotheses=partial_hypothesis + ) # sorted list of hypothesis + + # Prepare the list of hypotheses + nbest_hyps = pack_hypotheses(nbest_hyps) + + # Pack the result + if self.return_best_hypothesis: + best_hypothesis: Hypothesis = nbest_hyps[0] + else: + best_hypothesis: NBestHypotheses = NBestHypotheses(nbest_hyps) + hypotheses.append(best_hypothesis) + + self.decoder.train(decoder_training_state) + self.joint.train(joint_training_state) + + return (hypotheses,) + + def default_beam_search( + self, + encoder_outputs: torch.Tensor, + encoded_lengths: torch.Tensor, + partial_hypotheses: Optional[Hypothesis] = None, + ) -> List[Hypothesis]: + """Default Beam search implementation for TDT models. + + Args: + encoder_outputs: encoder outputs (batch, features, timesteps). + encoded_lengths: lengths of the encoder outputs. + partial_hypotheses: partial hypoteses. 
+ + Returns: + nbest_hyps: N-best decoding results + """ + if partial_hypotheses is not None: + raise NotImplementedError("Support for `partial_hypotheses` is not implemented.") + + beam = min(self.beam_size, self.vocab_size) + beam_k = min(beam, (self.vocab_size - 1)) + durations_beam_k = min(beam, len(self.durations)) + + # Initialize zero vector states. + decoder_state = self.decoder.initialize_state(encoder_outputs) + # Cache decoder results to avoid duplicate computations. + cache = {} + + # Initialize hypothesis array with blank hypothesis. + start_hyp = Hypothesis( + score=0.0, y_sequence=[self.blank], dec_state=decoder_state, timestep=[-1], length=0, last_frame=0 + ) + kept_hyps = [start_hyp] + + for time_idx in range(int(encoded_lengths)): + # Retrieve hypotheses for current and future frames + hyps = [hyp for hyp in kept_hyps if hyp.last_frame == time_idx] # hypotheses for current frame + kept_hyps = [hyp for hyp in kept_hyps if hyp.last_frame > time_idx] # hypothesis for future frames + + # Loop over hypotheses of current frame + while len(hyps) > 0: + max_hyp = max(hyps, key=lambda x: x.score) + hyps.remove(max_hyp) + + # Update decoder state and get probability distribution over vocabulary and durations. + encoder_output = encoder_outputs[:, time_idx : time_idx + 1, :] # [1, 1, D] + decoder_output, decoder_state, _ = self.decoder.score_hypothesis(max_hyp, cache) # [1, 1, D] + logits = ( + self.joint.joint(encoder_output, decoder_output) / self.softmax_temperature + ) # [1, 1, 1, V + NUM_DURATIONS + 1] + logp = torch.log_softmax(logits[0, 0, 0, : -len(self.durations)], dim=-1) # [V + 1] + durations_logp = torch.log_softmax(logits[0, 0, 0, -len(self.durations) :], dim=-1) # [NUM_DURATIONS] + + # Proccess non-blank tokens + # Retrieve the top `beam_k` most probable tokens and the top `duration_beam_k` most probable durations. + # Then, select the top `beam_k` pairs of (token, duration) based on the highest combined probabilities. + # Note that indices are obtained in the flattened array. 
+ logp_topks, logp_topk_idxs = logp[:-1].topk(beam_k, dim=-1) # topk of tokens without blank token + durations_logp_topks, durations_logp_topk_idxs = durations_logp.topk(durations_beam_k, dim=-1) + total_logp_topks, total_logp_topk_idxs = ( + torch.cartesian_prod(durations_logp_topks, logp_topks).sum(dim=-1).topk(beam_k, dim=-1) + ) + + # Loop over pairs of (token, duration) with highest combined log prob + for total_logp_topk, total_logp_topk_idx in zip(total_logp_topks, total_logp_topk_idxs): + # Restore indices from flattened array indices + token_idx = int(logp_topk_idxs[total_logp_topk_idx % beam_k]) + duration_idx = int(durations_logp_topk_idxs[total_logp_topk_idx // beam_k]) + + duration = self.durations[duration_idx] + # Construct hypothesis for non-blank token + new_hyp = Hypothesis( + score=float(max_hyp.score + total_logp_topk), # update score + y_sequence=max_hyp.y_sequence + [token_idx], # update hypothesis sequence + dec_state=decoder_state, # update decoder state + timestep=max_hyp.timestep + [time_idx + duration], # update timesteps + length=encoded_lengths, + last_frame=max_hyp.last_frame + duration, + ) # update frame idx where last token appeared + + # Update current frame hypotheses if duration is zero and future frame hypotheses otherwise + if duration == 0: + hyps.append(new_hyp) + else: + kept_hyps.append(new_hyp) + + # Update future frames with blank tokens + # Note: blank token can have only non-zero duration + for duration_idx in durations_logp_topk_idxs: + duration_idx = int(duration_idx) + # If zero is the only duration in topk, switch to closest non-zero duration to continue + if duration_idx == self.zero_duration_idx: + if durations_logp_topk_idxs.shape[0] == 1: + duration_idx = self.min_non_zero_duration_idx + else: + continue + + duration = self.durations[duration_idx] + new_hyp = Hypothesis( + score=float(max_hyp.score + logp[self.blank] + durations_logp[duration_idx]), # update score + y_sequence=max_hyp.y_sequence[:], # no need to update sequence + dec_state=max_hyp.dec_state, # no need to update decoder state + timestep=max_hyp.timestep[:], # no need to update timesteps + length=encoded_lengths, + last_frame=max_hyp.last_frame + duration, + ) # update frame idx where last token appeared + kept_hyps.append(new_hyp) + + # Merge duplicate hypotheses. + # If two consecutive blank tokens are predicted and their duration values sum up to the same number, + # it will produce two hypotheses with the same token sequence but different scores. + kept_hyps = self.merge_duplicate_hypotheses(kept_hyps) + + if len(hyps) > 0: + # Keep those hypothesis that have scores greater than next search generation + hyps_max = float(max(hyps, key=lambda x: x.score).score) + kept_most_prob = sorted( + [hyp for hyp in kept_hyps if hyp.score > hyps_max], + key=lambda x: x.score, + ) + # If enough hypotheses have scores greater than next search generation, + # stop beam search. + if len(kept_most_prob) >= beam: + kept_hyps = kept_most_prob + break + else: + # If there are no hypotheses in a current frame, + # keep only `beam` best hypotheses for the next search generation. + kept_hyps = sorted(kept_hyps, key=lambda x: x.score, reverse=True)[:beam] + return self.sort_nbest(kept_hyps) + + def modified_adaptive_expansion_search( + self, + encoder_outputs: torch.Tensor, + encoded_lengths: torch.Tensor, + partial_hypotheses: Optional[Hypothesis] = None, + ) -> List[Hypothesis]: + """ + Modified Adaptive Exoansion Search algorithm for TDT models. 
+ Based on/modified from https://ieeexplore.ieee.org/document/9250505. + Supports N-gram language model shallow fusion. + + Args: + encoder_outputs: encoder outputs (batch, features, timesteps). + encoded_lengths: lengths of the encoder outputs. + partial_hypotheses: partial hypotheses. + + Returns: + nbest_hyps: N-best decoding results + """ + if partial_hypotheses is not None: + raise NotImplementedError("Support for `partial_hypotheses` is not implemented.") + + beam = min(self.beam_size, self.vocab_size) + beam_state = self.decoder.initialize_state( + torch.zeros(1, device=encoder_outputs.device, dtype=encoder_outputs.dtype) + ) # [L, B, H], [L, B, H] for LSTMS + + # Initialize first hypothesis for the beam (blank). + start_hyp = Hypothesis( + y_sequence=[self.blank], + score=0.0, + dec_state=self.decoder.batch_select_state(beam_state, 0), + timestep=[-1], + length=0, + last_frame=0, + ) + init_tokens = [start_hyp] + + # Cache decoder results to avoid duplicate computations. + cache = {} + + # Decode a batch of beam states and scores + beam_decoder_output, beam_state = self.decoder.batch_score_hypothesis(init_tokens, cache) + state = beam_state[0] + + # Initialize first hypothesis for the beam (blank) for kept hypotheses + start_hyp_kept = Hypothesis( + y_sequence=[self.blank], + score=0.0, + dec_state=state, + dec_out=[beam_decoder_output[0]], + timestep=[-1], + length=0, + last_frame=0, + ) + + kept_hyps = [start_hyp_kept] + + # Setup ngram LM: + if self.ngram_lm: + init_lm_state = kenlm.State() + self.ngram_lm.BeginSentenceWrite(init_lm_state) + start_hyp_kept.ngram_lm_state = init_lm_state + + for time_idx in range(encoded_lengths): + # Select current iteration hypotheses + hyps = [x for x in kept_hyps if x.last_frame == time_idx] + kept_hyps = [x for x in kept_hyps if x.last_frame > time_idx] + + if len(hyps) == 0: + continue + + beam_encoder_output = encoder_outputs[:, time_idx : time_idx + 1] # [1, 1, D] + # Perform prefix search to update hypothesis scores. + if self.zero_duration_idx is not None: + hyps = self.prefix_search( + sorted(hyps, key=lambda x: len(x.y_sequence), reverse=True), + beam_encoder_output, + prefix_alpha=self.maes_prefix_alpha, + ) + + list_b = [] # List that contains the blank token emissions + list_nb = [] # List that contains the non-zero duration non-blank token emissions + # Repeat for number of mAES steps + for n in range(self.maes_num_steps): + # Pack the decoder logits for all current hypotheses + beam_decoder_output = torch.stack([h.dec_out[-1] for h in hyps]) # [H, 1, D] + + # Extract the log probabilities + beam_logits = self.joint.joint(beam_encoder_output, beam_decoder_output) / self.softmax_temperature + beam_logp = torch.log_softmax(beam_logits[:, 0, 0, : -len(self.durations)], dim=-1) + beam_duration_logp = torch.log_softmax(beam_logits[:, 0, 0, -len(self.durations) :], dim=-1) + + # Retrieve the top `max_candidades` most probable tokens. + # Then, select the top `max_candidates` pairs of (token, duration) + # based on the highest combined probabilities. + # Note that indices are obtained in flattened array. 
+ beam_logp_topks, beam_idx_topks = beam_logp.topk(self.max_candidates, dim=-1) + beam_total_logp = (beam_duration_logp[:, :, None] + beam_logp_topks[:, None, :]).view( + len(hyps), -1 + ) # [B, MAX_CANDIDATES*DURATION_BEAM] + beam_total_logp_topks, beam_total_logp_topk_idxs = beam_total_logp.topk( + self.max_candidates, dim=-1 + ) # [B, MAX_CANDIDATES] + + # Prune hypothesis to obtain k expansions + beam_best_expansion_scores = beam_total_logp_topks.max(dim=-1, keepdim=True).values + beam_masks = beam_total_logp_topks >= beam_best_expansion_scores - self.maes_expansion_gamma + beam_kexpansions_idxs = [ + sum_logp_topk_idxs[mask] for sum_logp_topk_idxs, mask in zip(beam_total_logp_topk_idxs, beam_masks) + ] + + list_exp = [] # List that contains the hypothesis expansion + list_nb_exp = [] # List that contains the hypothesis expansion + for hyp_idx, hyp in enumerate(hyps): # For all hypothesis + for idx in beam_kexpansions_idxs[hyp_idx]: # For all expansions within this hypothesis + # Restore indices in logp and durations_logp arrays from flattened indices. + k = int(beam_idx_topks[hyp_idx][idx % self.max_candidates]) + duration = self.durations[int(idx // self.max_candidates)] + total_logp = float(beam_total_logp[hyp_idx][idx]) + + # Forcing blank token to have non-zero duration + if k == self.blank and duration == 0: + duration = self.durations[self.min_non_zero_duration_idx] + + new_hyp = Hypothesis( + score=hyp.score + total_logp, + y_sequence=hyp.y_sequence[:], + dec_out=hyp.dec_out[:], + dec_state=hyp.dec_state, + timestep=hyp.timestep[:], + length=time_idx, + last_frame=hyp.last_frame + duration, + ) + + if self.ngram_lm: + new_hyp.ngram_lm_state = hyp.ngram_lm_state + + # If the expansion was for blank + if k == self.blank: + list_b.append(new_hyp) + else: + new_hyp.y_sequence.append(k) + new_hyp.timestep.append(time_idx + duration) + + if self.ngram_lm: + lm_score, new_hyp.ngram_lm_state = self.compute_ngram_score(hyp.ngram_lm_state, int(k)) + new_hyp.score += self.ngram_lm_alpha * lm_score + + # If token duration is 0 adding to expansions list + if duration == 0: + list_exp.append(new_hyp) + else: + list_nb_exp.append(new_hyp) + + # Update states for hypothesis that do not end with blank + hyps_to_update = list_nb_exp + list_exp + if len(hyps_to_update) > 0: + # Decode a batch of beam states and scores + beam_decoder_output, beam_state = self.decoder.batch_score_hypothesis( + hyps_to_update, + cache, + ) + for hyp_idx, hyp in enumerate(hyps_to_update): + # Preserve the decoder logits for the current beam + hyp.dec_out.append(beam_decoder_output[hyp_idx]) + hyp.dec_state = beam_state[hyp_idx] + + # If there were no token expansions in any of the hypotheses, + # Early exit + list_nb += list_nb_exp + if not list_exp: + kept_hyps = kept_hyps + list_b + list_nb + kept_hyps = self.merge_duplicate_hypotheses(kept_hyps) + kept_hyps = sorted(kept_hyps, key=lambda x: x.score, reverse=True)[:beam] + + break + else: + # If this isn't the last mAES step + if n < (self.maes_num_steps - 1): + # Copy the expanded hypothesis for the next iteration + hyps = self.merge_duplicate_hypotheses(list_exp) + else: + # If this is the last mAES step add probabilities of the blank token to the end. 
+                        # Extract the log probabilities
+                        beam_decoder_output = torch.stack([h.dec_out[-1] for h in list_exp])  # [H, 1, D]
+                        beam_logits = (
+                            self.joint.joint(beam_encoder_output, beam_decoder_output) / self.softmax_temperature
+                        )
+                        beam_logp = torch.log_softmax(beam_logits[:, 0, 0, : -len(self.durations)], dim=-1)
+
+                        # Get most probable durations
+                        beam_duration_logp = torch.log_softmax(beam_logits[:, 0, 0, -len(self.durations) :], dim=-1)
+                        _, beam_max_duration_idx = torch.max(beam_duration_logp, dim=-1)
+
+                        # For all expansions, add the score for the blank label
+                        for hyp_idx, hyp in enumerate(list_exp):
+                            # If zero duration was obtained, change to the closest non-zero duration
+                            duration_idx = int(beam_max_duration_idx[hyp_idx])
+                            if duration_idx == self.zero_duration_idx:
+                                duration_idx = self.min_non_zero_duration_idx
+
+                            total_logp = float(
+                                beam_logp[hyp_idx, self.blank] + beam_duration_logp[hyp_idx, duration_idx]
+                            )
+                            hyp.score += total_logp
+                            hyp.last_frame += self.durations[duration_idx]
+
+                        # Finally, update the kept hypotheses with the sorted top `beam` candidates
+                        kept_hyps = kept_hyps + list_b + list_exp + list_nb
+                        kept_hyps = self.merge_duplicate_hypotheses(kept_hyps)
+                        kept_hyps = sorted(kept_hyps, key=lambda x: x.score, reverse=True)[:beam]
+
+        # Sort the hypotheses by score
+        return self.sort_nbest(kept_hyps)
+
+    def merge_duplicate_hypotheses(self, hypotheses):
+        """
+        Merges hypotheses with identical token sequences and lengths.
+        The combined hypothesis's probability is the sum of the probabilities of all duplicates.
+        Duplicate hypotheses occur when two consecutive blank tokens are predicted
+        and their duration values sum up to the same number.
+
+        Args:
+            hypotheses: list of hypotheses.
+
+        Returns:
+            hypotheses: list of hypotheses without duplicates.
+        """
+        sorted_hyps = sorted(hypotheses, key=lambda x: x.score, reverse=True)
+        kept_hyps = {}
+        for hyp in sorted_hyps:
+            hyp_key = (tuple(hyp.y_sequence), int(hyp.last_frame))
+            if hyp_key in kept_hyps:
+                kept_hyp = kept_hyps[hyp_key]
+                kept_hyp.score = float(torch.logaddexp(torch.tensor(kept_hyp.score), torch.tensor(hyp.score)))
+            else:
+                kept_hyps[hyp_key] = hyp
+        return list(kept_hyps.values())
+
+    def set_decoding_type(self, decoding_type: str):
+        """
+        Sets the decoding type. Please check train_kenlm.py in scripts/asr_language_modeling/
+        to find out why we need the token offset for BPE-based models.
+
+        Args:
+            decoding_type: decoding type
+        """
+        # TOKEN_OFFSET for BPE-based models
+        if decoding_type == 'subword':
+            from nemo.collections.asr.parts.submodules.ctc_beam_decoding import DEFAULT_TOKEN_OFFSET
+
+            self.token_offset = DEFAULT_TOKEN_OFFSET
+
+    def prefix_search(
+        self, hypotheses: List[Hypothesis], encoder_output: torch.Tensor, prefix_alpha: int
+    ) -> List[Hypothesis]:
+        """
+        Performs a prefix search and updates the scores of the hypotheses in place.
+        Based on https://arxiv.org/pdf/1211.3711.pdf.
+
+        Args:
+            hypotheses: a list of hypotheses sorted by the length from the longest to the shortest.
+            encoder_output: encoder output.
+            prefix_alpha: maximum allowable length difference between hypothesis and a prefix.
+
+        Returns:
+            hypotheses: list of hypotheses with updated scores.
+        """
+        # Iterate over hypotheses.
+        for curr_idx, curr_hyp in enumerate(hypotheses[:-1]):
+            # For each hypothesis, iterate over the subsequent hypotheses.
+            # If a hypothesis is a prefix of the current one, update current score.
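+            # Example: if pref_hyp = [a, b] is a prefix of curr_hyp = [a, b, c, d], the probability of
+            # reaching curr_hyp by first emitting pref_hyp is accumulated into curr_hyp.score (log-sum-exp).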
+            for pref_hyp in hypotheses[(curr_idx + 1) :]:
+                curr_hyp_length = len(curr_hyp.y_sequence)
+                pref_hyp_length = len(pref_hyp.y_sequence)
+
+                if (
+                    is_prefix(curr_hyp.y_sequence, pref_hyp.y_sequence)
+                    and (curr_hyp_length - pref_hyp_length) <= prefix_alpha
+                ):
+                    # Compute the score of the first token
+                    # that follows the prefix hypothesis tokens in current hypothesis.
+                    # Use the decoder output, which is stored in the prefix hypothesis.
+                    logits = self.joint.joint(encoder_output, pref_hyp.dec_out[-1]) / self.softmax_temperature
+                    logp = torch.log_softmax(logits[0, 0, 0, : -len(self.durations)], dim=-1)
+                    duration_logp = torch.log_softmax(logits[0, 0, 0, -len(self.durations) :], dim=-1)
+                    curr_score = pref_hyp.score + float(
+                        logp[curr_hyp.y_sequence[pref_hyp_length]] + duration_logp[self.zero_duration_idx]
+                    )
+
+                    if self.ngram_lm:
+                        lm_score, next_state = self.compute_ngram_score(
+                            pref_hyp.ngram_lm_state, int(curr_hyp.y_sequence[pref_hyp_length])
+                        )
+                        curr_score += self.ngram_lm_alpha * lm_score
+
+                    for k in range(pref_hyp_length, (curr_hyp_length - 1)):
+                        # Compute the score of the next token.
+                        # Approximate decoder output with the one that is stored in current hypothesis.
+                        logits = self.joint.joint(encoder_output, curr_hyp.dec_out[k]) / self.softmax_temperature
+                        logp = torch.log_softmax(logits[0, 0, 0, : -len(self.durations)], dim=-1)
+                        duration_logp = torch.log_softmax(logits[0, 0, 0, -len(self.durations) :], dim=-1)
+                        curr_score += float(logp[curr_hyp.y_sequence[k + 1]] + duration_logp[self.zero_duration_idx])
+
+                        if self.ngram_lm:
+                            lm_score, next_state = self.compute_ngram_score(
+                                next_state, int(curr_hyp.y_sequence[k + 1])
+                            )
+                            curr_score += self.ngram_lm_alpha * lm_score
+
+                    # Update current hypothesis score
+                    curr_hyp.score = np.logaddexp(curr_hyp.score, curr_score)
+        return hypotheses
+
+    def compute_ngram_score(self, current_lm_state: "kenlm.State", label: int) -> Tuple[float, "kenlm.State"]:
+        """
+        Computes the score for the KenLM N-gram language model.
+
+        Args:
+            current_lm_state: current state of the KenLM language model.
+            label: next label.
+
+        Returns:
+            lm_score: score for `label`.
+            next_state: KenLM state after scoring `label`.
+        """
+        if self.token_offset:
+            label = chr(label + self.token_offset)
+        else:
+            label = str(label)
+
+        next_state = kenlm.State()
+        lm_score = self.ngram_lm.BaseScore(current_lm_state, label, next_state)
+        lm_score *= 1.0 / np.log10(np.e)
+
+        return lm_score, next_state
+
+    def sort_nbest(self, hyps: List[Hypothesis]) -> List[Hypothesis]:
+        """Sort hypotheses by score or score given sequence length.
+
+        Args:
+            hyps: list of hypotheses
+
+        Returns:
+            hyps: sorted list of hypotheses
+        """
+        if self.score_norm:
+            return sorted(hyps, key=lambda x: x.score / len(x.y_sequence), reverse=True)
+        else:
+            return sorted(hyps, key=lambda x: x.score, reverse=True)
diff --git a/nemo/collections/asr/parts/utils/rnnt_utils.py b/nemo/collections/asr/parts/utils/rnnt_utils.py
index 76e9da6087ed..8d2755fcc0ae 100644
--- a/nemo/collections/asr/parts/utils/rnnt_utils.py
+++ b/nemo/collections/asr/parts/utils/rnnt_utils.py
@@ -85,6 +85,8 @@ class Hypothesis:
         tokens: (Optional) A list of decoded tokens (can be characters or word-pieces.
         last_token (Optional): A token or batch of tokens which was predicted in the last step.
+
+        last_frame (Optional): Index of the last decoding step at which the hypothesis was updated, including blank token predictions.
""" score: float @@ -105,6 +107,7 @@ class Hypothesis: tokens: Optional[Union[List[int], torch.Tensor]] = None last_token: Optional[torch.Tensor] = None token_duration: Optional[List[int]] = None + last_frame: Optional[int] = None @property def non_blank_frame_confidence(self) -> List[float]: @@ -244,7 +247,8 @@ def __init__( Args: batch_size: batch size for hypotheses - init_length: initial estimate for the length of hypotheses (if the real length is higher, tensors will be reallocated) + init_length: initial estimate for the length of hypotheses (if the real length is higher, + tensors will be reallocated) device: device for storing hypotheses float_dtype: float type for scores """ @@ -274,6 +278,9 @@ def __init__( self._ones_batch = torch.ones_like(self._batch_indices) def clear_(self): + """ + Clears batched hypotheses state. + """ self.current_lengths.fill_(0) self.transcript.fill_(0) self.timesteps.fill_(0) @@ -497,6 +504,9 @@ def __init__( self._batch_indices = torch.arange(batch_size, device=device) def clear_(self): + """ + Clears batched hypotheses state. + """ self.current_lengths.fill_(0) self.timesteps.fill_(0) self.logits.fill_(0.0) diff --git a/nemo/collections/diffusion/data/diffusion_energon_datamodule.py b/nemo/collections/diffusion/data/diffusion_energon_datamodule.py index f18c828d9d45..67a26609dd51 100644 --- a/nemo/collections/diffusion/data/diffusion_energon_datamodule.py +++ b/nemo/collections/diffusion/data/diffusion_energon_datamodule.py @@ -15,7 +15,8 @@ import logging from typing import Any, Dict, Literal -from megatron.energon import DefaultTaskEncoder, get_train_dataset +from megatron.core import parallel_state +from megatron.energon import DefaultTaskEncoder, WorkerConfig, get_savable_loader, get_train_dataset from pytorch_lightning.utilities.types import EVAL_DATALOADERS from nemo.collections.multimodal.data.energon.base import SimpleMultiModalDataModule @@ -56,6 +57,9 @@ def __init__( pin_memory: bool = True, task_encoder: DefaultTaskEncoder = None, use_train_split_for_val: bool = False, + virtual_epoch_length: int = 1_000_000_000, # a hack to avoid energon end of epoch warning + packing_buffer_size: int | None = None, + max_samples_per_sequence: int | None = None, ) -> None: """ Initialize the SimpleMultiModalDataModule. @@ -82,6 +86,10 @@ def __init__( task_encoder=task_encoder, ) self.use_train_split_for_val = use_train_split_for_val + self.virtual_epoch_length = virtual_epoch_length + self.num_workers_val = 1 + self.packing_buffer_size = packing_buffer_size + self.max_samples_per_sequence = max_samples_per_sequence def datasets_provider(self, worker_config, split: Literal['train', 'val'] = 'val'): """ @@ -106,29 +114,55 @@ def datasets_provider(self, worker_config, split: Literal['train', 'val'] = 'val batch_size=self.micro_batch_size, task_encoder=self.task_encoder, worker_config=worker_config, - max_samples_per_sequence=None, - shuffle_buffer_size=100, + max_samples_per_sequence=self.max_samples_per_sequence, + shuffle_buffer_size=None, split_part=split, - batch_drop_last=True, - virtual_epoch_length=1_000_000_000, # a hack to avoid energon end of epoch warning + virtual_epoch_length=self.virtual_epoch_length, + packing_buffer_size=self.packing_buffer_size, ) return _dataset def val_dataloader(self) -> EVAL_DATALOADERS: """ - Configure the validation DataLoader. + Initialize and return the validation DataLoader. - This method configures the DataLoader for validation data. - - Parameters: - worker_config: Configuration for the data loader workers. 
+ This method initializes the DataLoader for the validation dataset. It ensures that the parallel state + is initialized correctly for distributed training and returns a configured DataLoader object. Returns: - DataLoader: The DataLoader for validation data. + EVAL_DATALOADERS: The DataLoader for the validation dataset. """ if self.use_train_split_for_val: return self.train_dataloader() - return super().val_dataloader() + if self.val_dataloader_object: + return self.val_dataloader_object + + if not parallel_state.is_initialized(): + message = ( + "Muiltimodal val data loader parallel state is not initialized " + f"using default worker config with no_workers {self.num_workers}" + ) + logging.info(message) + + worker_config = WorkerConfig.default_worker_config(self.num_workers_val) + else: + rank = parallel_state.get_data_parallel_rank() + world_size = parallel_state.get_data_parallel_world_size() + data_parallel_group = parallel_state.get_data_parallel_group() + + logging.info(f"rank {rank} world_size {world_size} data_parallel_group {data_parallel_group}") + worker_config = WorkerConfig( + rank=rank, + world_size=world_size, + num_workers=self.num_workers_val, + data_parallel_group=data_parallel_group, + worker_debug_path=None, + worker_log_level=0, + ) + val_dataset = self.datasets_provider(worker_config, split='val') + energon_loader = get_savable_loader(val_dataset, worker_config=worker_config) + self.val_dataloader_object = energon_loader + return self.val_dataloader_object def load_state_dict(self, state_dict: Dict[str, Any]) -> None: """ diff --git a/nemo/collections/diffusion/data/diffusion_fake_datamodule.py b/nemo/collections/diffusion/data/diffusion_fake_datamodule.py new file mode 100644 index 000000000000..6cb686c1c305 --- /dev/null +++ b/nemo/collections/diffusion/data/diffusion_fake_datamodule.py @@ -0,0 +1,218 @@ +# Copyright (c) 2024, NVIDIA CORPORATION. All rights reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +import pytorch_lightning as pl +import torch +from pytorch_lightning.utilities.types import EVAL_DATALOADERS, TRAIN_DATALOADERS +from torch.utils.data import DataLoader + +from nemo.collections.diffusion.models.model import DiTConfig +from nemo.lightning.pytorch.plugins import MegatronDataSampler + +from .diffusion_taskencoder import pos_id_3d + + +class PosEmb3D: + """Generates and provides 3D positional embeddings for video data.""" + + def __init__(self, *, max_t=96, max_h=960, max_w=960): + self.max_t = max_t + self.max_h = max_h + self.max_w = max_w + self.generate_pos_id() + + def generate_pos_id(self): + """Generates the positional ID grid based on max_t, max_h, and max_w.""" + self.grid = torch.stack( + torch.meshgrid( + torch.arange(self.max_t, device='cpu'), + torch.arange(self.max_h, device='cpu'), + torch.arange(self.max_w, device='cpu'), + ), + dim=-1, + ) + + def get_pos_id_3d(self, *, t, h, w): + """Retrieves a subset of the positional IDs for the specified dimensions. + + Parameters: + t (int): Number of time frames. 
+ h (int): Height dimension. + w (int): Width dimension. + + Returns: + torch.Tensor: The positional IDs tensor with shape (t, h, w, 3). + """ + if t > self.max_t or h > self.max_h or w > self.max_w: + self.max_t = max(self.max_t, t) + self.max_h = max(self.max_h, h) + self.max_w = max(self.max_w, w) + self.generate_pos_id() + return self.grid[:t, :h, :w] + + +class DiTVideoLatentFakeDataset(torch.utils.data.Dataset): + """A fake dataset for generating synthetic video latent data.""" + + def __init__( + self, + n_frames, + max_h, + max_w, + patch_size, + in_channels, + crossattn_emb_size, + max_text_seqlen=512, + seq_length=8192, + ): + self.max_t = n_frames + self.max_height = max_h + self.max_width = max_w + self.patch_size = patch_size + self.in_channels = in_channels + self.text_dim = crossattn_emb_size + self.text_seqlen = max_text_seqlen + self.seq_length = seq_length + + def __len__(self): + """Returns the total number of samples.""" + return 100000000 + + def __getitem__(self, idx): + """Generates a single sample of data. + + Parameters: + idx (int): Index of the data sample. + + Returns: + dict: A dictionary containing video latent data and related information. + """ + t = self.max_t + h = self.max_height + w = self.max_width + p = self.patch_size + c = self.in_channels + + video_latent = torch.ones(self.seq_length, c * p**2, dtype=torch.bfloat16) * 0.5 + text_embedding = torch.randn(self.text_seqlen, self.text_dim, dtype=torch.bfloat16) + pos_emb = pos_id_3d.get_pos_id_3d(t=t, h=h // p, w=w // p).reshape(-1, 3) + + return { + 'video': video_latent, + 't5_text_embeddings': text_embedding, + 'seq_len_q': torch.tensor([video_latent.shape[0]], dtype=torch.int32).squeeze(), + 'seq_len_kv': torch.tensor([self.text_seqlen], dtype=torch.int32).squeeze(), + 'pos_ids': torch.zeros((self.seq_length, 3), dtype=torch.int32), + 'loss_mask': torch.ones(video_latent.shape[0], dtype=torch.bfloat16), + } + + def _collate_fn(self, batch): + """A default implementation of a collation function. + + Users should override this method to define custom data loaders. + """ + return torch.utils.data.dataloader.default_collate(batch) + + def collate_fn(self, batch): + """Method that user passes as a functor to DataLoader. + + The method optionally performs neural type checking and adds types to the outputs. + + Please note, subclasses of Dataset should not implement `input_types`. + + Usage: + dataloader = torch.utils.data.DataLoader( + ...., + collate_fn=dataset.collate_fn, + .... + ) + + Returns: + Collated batch, with or without types. + """ + return self._collate_fn(batch) + + +class VideoLatentFakeDataModule(pl.LightningDataModule): + """A LightningDataModule for generating fake video latent data for training.""" + + def __init__( + self, + model_config: DiTConfig, + seq_length: int = 2048, + micro_batch_size: int = 1, + global_batch_size: int = 8, + num_workers: int = 1, + pin_memory: bool = True, + task_encoder=None, + use_train_split_for_val: bool = False, + ) -> None: + super().__init__() + self.seq_length = seq_length + self.micro_batch_size = micro_batch_size + self.global_batch_size = global_batch_size + self.num_workers = num_workers + self.model_config = model_config + + self.data_sampler = MegatronDataSampler( + seq_len=self.seq_length, + micro_batch_size=micro_batch_size, + global_batch_size=global_batch_size, + ) + + def setup(self, stage: str = "") -> None: + """Sets up the dataset for training and validation. + + Parameters: + stage (str): Optional stage argument (unused). 
+ """ + self._train_ds = DiTVideoLatentFakeDataset( + n_frames=self.model_config.max_frames, + max_h=self.model_config.max_img_h, + max_w=self.model_config.max_img_w, + patch_size=self.model_config.patch_spatial, + in_channels=self.model_config.in_channels, + crossattn_emb_size=self.model_config.crossattn_emb_size, + ) + + def train_dataloader(self) -> TRAIN_DATALOADERS: + """Returns the training DataLoader.""" + if not hasattr(self, "_train_ds"): + self.setup() + return self._create_dataloader(self._train_ds) + + def val_dataloader(self) -> EVAL_DATALOADERS: + """Returns the validation DataLoader.""" + if not hasattr(self, "_train_ds"): + self.setup() + return self._create_dataloader(self._train_ds) + + def _create_dataloader(self, dataset, **kwargs) -> DataLoader: + """Creates a DataLoader for the given dataset. + + Parameters: + dataset (Dataset): The dataset to load. + **kwargs: Additional arguments for DataLoader. + + Returns: + DataLoader: The DataLoader instance. + """ + return DataLoader( + dataset, + num_workers=self.num_workers, + pin_memory=True, + persistent_workers=True, + collate_fn=dataset.collate_fn, + **kwargs, + ) diff --git a/nemo/collections/diffusion/data/diffusion_taskencoder.py b/nemo/collections/diffusion/data/diffusion_taskencoder.py index 57e4e4ec8673..2a42b15453b3 100644 --- a/nemo/collections/diffusion/data/diffusion_taskencoder.py +++ b/nemo/collections/diffusion/data/diffusion_taskencoder.py @@ -12,15 +12,96 @@ # See the License for the specific language governing permissions and # limitations under the License. -import warnings +import random +from dataclasses import dataclass +from typing import Any, List, Optional + import torch import torch.nn.functional as F from einops import rearrange -from megatron.core import parallel_state -from megatron.energon import DefaultTaskEncoder, SkipSample +from megatron.energon import DefaultTaskEncoder, Sample, SkipSample +from megatron.energon.task_encoder.base import stateless from megatron.energon.task_encoder.cooking import Cooker, basic_sample_keys from nemo.lightning.io.mixin import IOMixin +from nemo.utils.sequence_packing_utils import first_fit_decreasing + + +@dataclass +class DiffusionSample(Sample): + """ + Data class representing a sample for diffusion tasks. + + Attributes: + video (torch.Tensor): Video latents (C T H W). + t5_text_embeddings (torch.Tensor): Text embeddings (S D). + t5_text_mask (torch.Tensor): Mask for text embeddings. + loss_mask (torch.Tensor): Mask indicating valid positions for loss computation. + image_size (Optional[torch.Tensor]): Tensor containing image dimensions. + fps (Optional[torch.Tensor]): Frame rate of the video. + num_frames (Optional[torch.Tensor]): Number of frames in the video. + padding_mask (Optional[torch.Tensor]): Mask indicating padding positions. + seq_len_q (Optional[torch.Tensor]): Sequence length for query embeddings. + seq_len_kv (Optional[torch.Tensor]): Sequence length for key/value embeddings. + pos_ids (Optional[torch.Tensor]): Positional IDs. + latent_shape (Optional[torch.Tensor]): Shape of the latent tensor. 
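+
+    Note:
+        `seq_len_q` also acts as the packing key: `__add__`, `__radd__`, and `__lt__` compare and
+        accumulate samples by `seq_len_q` so that first-fit-decreasing sequence packing can group
+        samples by length.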
+ """ + + video: torch.Tensor # video latents (C T H W) + t5_text_embeddings: torch.Tensor # (S D) + t5_text_mask: torch.Tensor # 1 + loss_mask: torch.Tensor + image_size: Optional[torch.Tensor] = None + fps: Optional[torch.Tensor] = None + num_frames: Optional[torch.Tensor] = None + padding_mask: Optional[torch.Tensor] = None + seq_len_q: Optional[torch.Tensor] = None + seq_len_kv: Optional[torch.Tensor] = None + pos_ids: Optional[torch.Tensor] = None + latent_shape: Optional[torch.Tensor] = None + + def to_dict(self) -> dict: + """Converts the sample to a dictionary.""" + return dict( + video=self.video, + t5_text_embeddings=self.t5_text_embeddings, + t5_text_mask=self.t5_text_mask, + loss_mask=self.loss_mask, + image_size=self.image_size, + fps=self.fps, + num_frames=self.num_frames, + padding_mask=self.padding_mask, + seq_len_q=self.seq_len_q, + seq_len_kv=self.seq_len_kv, + pos_ids=self.pos_ids, + latent_shape=self.latent_shape, + ) + + def __add__(self, other: Any) -> int: + """Adds the sequence length of this sample with another sample or integer.""" + if isinstance(other, DiffusionSample): + # Combine the values of the two instances + return self.seq_len_q.item() + other.seq_len_q.item() + elif isinstance(other, int): + # Add an integer to the value + return self.seq_len_q.item() + other + raise NotImplementedError + + def __radd__(self, other: Any) -> int: + """Handles reverse addition for summing with integers.""" + # This is called if sum or other operations start with a non-DiffusionSample object. + # e.g., sum([DiffusionSample(1), DiffusionSample(2)]) -> the 0 + DiffusionSample(1) calls __radd__. + if isinstance(other, int): + return self.seq_len_q.item() + other + raise NotImplementedError + + def __lt__(self, other: Any) -> bool: + """Compares this sample's sequence length with another sample or integer.""" + if isinstance(other, DiffusionSample): + return self.seq_len_q.item() < other.seq_len_q.item() + elif isinstance(other, int): + return self.seq_len_q.item() < other + raise NotImplementedError def cook(sample: dict) -> dict: @@ -75,18 +156,26 @@ def __init__( max_frames: int = None, text_embedding_padding_size: int = 512, seq_length: int = None, + max_seq_length: int = None, patch_spatial: int = 2, patch_temporal: int = 1, + aesthetic_score: float = 0.0, **kwargs, ): super().__init__(*args, **kwargs) self.max_frames = max_frames self.text_embedding_padding_size = text_embedding_padding_size self.seq_length = seq_length + self.max_seq_length = max_seq_length self.patch_spatial = patch_spatial self.patch_temporal = patch_temporal + self.aesthetic_score = aesthetic_score + @stateless(restore_seeds=True) def encode_sample(self, sample: dict) -> dict: + """ + Encodes video / text sample. 
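+        Patchifies the video latent, prepares the padded T5 text embeddings, builds 3D positional
+        IDs and a loss mask, and returns a `DiffusionSample`. Samples with NaN/Inf latents, an
+        aesthetic score below `aesthetic_score`, or a sequence longer than the configured maximum
+        are skipped.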
+ """ video_latent = sample['pth'] if torch.isnan(video_latent).any() or torch.isinf(video_latent).any(): @@ -95,6 +184,9 @@ def encode_sample(self, sample: dict) -> dict: raise SkipSample() info = sample['json'] + if info['aesthetic_score'] < self.aesthetic_score: + raise SkipSample() + C, T, H, W = video_latent.shape seq_len = ( video_latent.shape[-1] @@ -105,19 +197,14 @@ def encode_sample(self, sample: dict) -> dict: ) is_image = T == 1 - if seq_len > self.seq_length: + if self.seq_length is not None and seq_len > self.seq_length: + raise SkipSample() + if self.max_seq_length is not None and seq_len > self.max_seq_length: raise SkipSample() if self.max_frames is not None: video_latent = video_latent[:, : self.max_frames, :, :] - tpcp_size = parallel_state.get_tensor_model_parallel_world_size() - if parallel_state.get_context_parallel_world_size() > 1: - tpcp_size *= parallel_state.get_context_parallel_world_size() * 2 - if (T * H * W) % tpcp_size != 0: - warnings.warn(f'skipping {video_latent.shape=} not divisible by {tpcp_size=}') - raise SkipSample() - video_latent = rearrange( video_latent, 'C (T pt) (H ph) (W pw) -> (T H W) (ph pw pt C)', @@ -161,7 +248,7 @@ def encode_sample(self, sample: dict) -> dict: 'T H W d -> (T H W) d', ) - if self.seq_length is not None: + if self.seq_length is not None and self.max_seq_length is None: pos_ids = F.pad(pos_ids, (0, 0, 0, self.seq_length - seq_len)) loss_mask = torch.zeros(self.seq_length, dtype=torch.bfloat16) loss_mask[:seq_len] = 1 @@ -169,7 +256,11 @@ def encode_sample(self, sample: dict) -> dict: else: loss_mask = torch.ones(seq_len, dtype=torch.bfloat16) - return dict( + return DiffusionSample( + __key__=sample['__key__'], + __restore_key__=sample['__restore_key__'], + __subflavor__=None, + __subflavors__=sample['__subflavors__'], video=video_latent, t5_text_embeddings=t5_text_embeddings, t5_text_mask=t5_text_mask, @@ -178,13 +269,86 @@ def encode_sample(self, sample: dict) -> dict: num_frames=num_frames, loss_mask=loss_mask, seq_len_q=torch.tensor(seq_len, dtype=torch.int32), - seq_len_kv=torch.tensor(t5_text_embeddings_seq_length, dtype=torch.int32), + seq_len_kv=torch.tensor(self.text_embedding_padding_size, dtype=torch.int32), pos_ids=pos_ids, latent_shape=torch.tensor([C, T, H, W], dtype=torch.int32), ) + def select_samples_to_pack(self, samples: List[DiffusionSample]) -> List[List[DiffusionSample]]: + """ + Selects sequences to pack for mixed image-video training. 
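+        Samples are grouped with a first-fit-decreasing bin packing over `seq_len_q` so that each
+        group fits within `max_seq_length`; the resulting groups are shuffled before being returned.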
+ """ + results = first_fit_decreasing(samples, self.max_seq_length) + random.shuffle(results) + return results + + @stateless + def pack_selected_samples(self, samples: List[DiffusionSample]) -> DiffusionSample: + """Construct a new Diffusion sample by concatenating the sequences.""" + + def stack(attr): + return torch.stack([getattr(sample, attr) for sample in samples], dim=0) + + def cat(attr): + return torch.cat([getattr(sample, attr) for sample in samples], dim=0) + + video = concat_pad([i.video for i in samples], self.max_seq_length) + loss_mask = concat_pad([i.loss_mask for i in samples], self.max_seq_length) + pos_ids = concat_pad([i.pos_ids for i in samples], self.max_seq_length) + + return DiffusionSample( + __key__=",".join([s.__key__ for s in samples]), + __restore_key__=(), # Will be set by energon based on `samples` + __subflavor__=None, + __subflavors__=samples[0].__subflavors__, + video=video, + t5_text_embeddings=cat('t5_text_embeddings'), + t5_text_mask=cat('t5_text_mask'), + # image_size=stack('image_size'), + # fps=stack('fps'), + # num_frames=stack('num_frames'), + loss_mask=loss_mask, + seq_len_q=stack('seq_len_q'), + seq_len_kv=stack('seq_len_kv'), + pos_ids=pos_ids, + latent_shape=stack('latent_shape'), + ) + + @stateless + def batch(self, samples: List[DiffusionSample]) -> dict: + """Return dictionary with data for batch.""" + if self.max_seq_length is None: + # no packing + return super().batch(samples).to_dict() + + # packing + sample = samples[0] + return dict( + video=sample.video.unsqueeze_(0), + t5_text_embeddings=sample.t5_text_embeddings.unsqueeze_(0), + t5_text_mask=sample.t5_text_mask.unsqueeze_(0), + loss_mask=sample.loss_mask.unsqueeze_(0), + # image_size=sample.image_size, + # fps=sample.fps, + # num_frames=sample.num_frames, + # padding_mask=sample.padding_mask.unsqueeze_(0), + seq_len_q=sample.seq_len_q, + seq_len_kv=sample.seq_len_kv, + pos_ids=sample.pos_ids.unsqueeze_(0), + latent_shape=sample.latent_shape, + ) + class PosID3D: + """ + Generates 3D positional IDs for video data. + + Attributes: + max_t (int): Maximum number of time frames. + max_h (int): Maximum height dimension. + max_w (int): Maximum width dimension. + """ + def __init__(self, *, max_t=32, max_h=128, max_w=128): self.max_t = max_t self.max_h = max_h @@ -192,6 +356,7 @@ def __init__(self, *, max_t=32, max_h=128, max_w=128): self.generate_pos_id() def generate_pos_id(self): + """Generates a grid of positional IDs based on max_t, max_h, and max_w.""" self.grid = torch.stack( torch.meshgrid( torch.arange(self.max_t, device='cpu'), @@ -202,6 +367,7 @@ def generate_pos_id(self): ) def get_pos_id_3d(self, *, t, h, w): + """Retrieves positional IDs for specified dimensions.""" if t > self.max_t or h > self.max_h or w > self.max_w: self.max_t = max(self.max_t, t) self.max_h = max(self.max_h, h) @@ -210,4 +376,70 @@ def get_pos_id_3d(self, *, t, h, w): return self.grid[:t, :h, :w] +def pad_divisible(x, padding_value=0): + """ + Pads the input tensor to make its size divisible by a specified value. + + Args: + x (torch.Tensor): Input tensor. + padding_value (int): The value to make the tensor size divisible by. + + Returns: + torch.Tensor: Padded tensor. 
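+
+    Example (illustrative):
+        >>> pad_divisible(torch.ones(20, 4), padding_value=16).shape
+        torch.Size([32, 4])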
+    """
+    if padding_value == 0:
+        return x
+    # Get the size of the first dimension
+    n = x.size(0)
+
+    # Compute the padding needed to make the first dimension divisible by `padding_value`
+    padding_needed = (padding_value - n % padding_value) % padding_value
+
+    if padding_needed <= 0:
+        return x
+
+    # Create a new shape with the padded first dimension
+    new_shape = list(x.shape)
+    new_shape[0] += padding_needed
+
+    # Create a new tensor filled with zeros
+    x_padded = torch.zeros(new_shape, dtype=x.dtype, device=x.device)
+
+    # Assign the original tensor to the beginning of the new tensor
+    x_padded[:n] = x
+    return x_padded
+
+
+def concat_pad(tensor_list, max_seq_length):
+    """
+    Efficiently concatenates a list of tensors along the first dimension and pads with zeros
+    to reach max_seq_length.
+
+    Args:
+        tensor_list (list of torch.Tensor): List of tensors to concatenate and pad.
+        max_seq_length (int): The desired size of the first dimension of the output tensor.
+
+    Returns:
+        torch.Tensor: A tensor of shape [max_seq_length, ...], where ... represents the remaining dimensions.
+    """
+    # Get common properties from the first tensor
+    other_shape = tensor_list[0].shape[1:]
+    dtype = tensor_list[0].dtype
+    device = tensor_list[0].device
+
+    # Initialize the result tensor with zeros
+    result = torch.zeros((max_seq_length, *other_shape), dtype=dtype, device=device)
+
+    current_index = 0
+    for tensor in tensor_list:
+        length = tensor.shape[0]
+        # Directly assign the tensor to the result tensor without checks
+        result[current_index : current_index + length] = tensor
+        current_index += length
+
+    return result
+
+
 pos_id_3d = PosID3D()
diff --git a/nemo/collections/diffusion/models/dit/dit_embeddings.py b/nemo/collections/diffusion/models/dit/dit_embeddings.py
index ec8d095cbbd4..6303db43bba1 100644
--- a/nemo/collections/diffusion/models/dit/dit_embeddings.py
+++ b/nemo/collections/diffusion/models/dit/dit_embeddings.py
@@ -55,6 +55,12 @@ def __init__(self, in_channels: int, time_embed_dim: int, seed=None):
         self.linear_1.reset_parameters()
         self.linear_2.reset_parameters()
+        if parallel_state.get_pipeline_model_parallel_world_size() > 1:
+            setattr(self.linear_1.weight, "pipeline_parallel", True)
+            setattr(self.linear_1.bias, "pipeline_parallel", True)
+            setattr(self.linear_2.weight, "pipeline_parallel", True)
+            setattr(self.linear_2.bias, "pipeline_parallel", True)
+
     def forward(self, x: torch.Tensor) -> torch.Tensor:
         """
         Computes the positional embeddings for the input tensor.
@@ -152,10 +158,27 @@ def __init__( self.emb_h = torch.nn.Embedding(h, config.hidden_size) self.emb_w = torch.nn.Embedding(w, config.hidden_size) - if config.perform_initialization: - config.init_method(self.emb_t.weight) - config.init_method(self.emb_h.weight) - config.init_method(self.emb_w.weight) + if 'seed' in kwargs.keys(): + seed = kwargs['seed'] + with torch.random.fork_rng(): + torch.manual_seed(seed) + if config.perform_initialization: + self.customize_init_param() + else: + self.reset_parameters() + else: + if config.perform_initialization: + self.customize_init_param() + + def customize_init_param(self): + self.config.init_method(self.emb_t.weight) + self.config.init_method(self.emb_h.weight) + self.config.init_method(self.emb_w.weight) + + def reset_parameters(self): + self.emb_t.reset_parameters() + self.emb_h.reset_parameters() + self.emb_w.reset_parameters() def forward(self, pos_ids: torch.Tensor): return self.emb_t(pos_ids[..., 0]) + self.emb_h(pos_ids[..., 1]) + self.emb_w(pos_ids[..., 2]) diff --git a/nemo/collections/diffusion/models/dit/dit_layer_spec.py b/nemo/collections/diffusion/models/dit/dit_layer_spec.py index cb7c520493f0..2233ef3a7354 100644 --- a/nemo/collections/diffusion/models/dit/dit_layer_spec.py +++ b/nemo/collections/diffusion/models/dit/dit_layer_spec.py @@ -733,8 +733,8 @@ def get_stdit_adaln_block_with_transformer_engine_spec() -> ModuleSpec: ) -def get_dit_adaln_block_with_transformer_engine_spec() -> ModuleSpec: - params = {"attn_mask_type": AttnMaskType.padding} +def get_dit_adaln_block_with_transformer_engine_spec(attn_mask_type=AttnMaskType.padding) -> ModuleSpec: + params = {"attn_mask_type": attn_mask_type} return ModuleSpec( module=DiTLayerWithAdaLN, submodules=DiTWithAdaLNSubmodules( diff --git a/nemo/collections/diffusion/models/dit/dit_model.py b/nemo/collections/diffusion/models/dit/dit_model.py index 0c1c1abc82f2..24943de6d675 100644 --- a/nemo/collections/diffusion/models/dit/dit_model.py +++ b/nemo/collections/diffusion/models/dit/dit_model.py @@ -141,7 +141,7 @@ def __init__( self.config: TransformerConfig = config - self.transformer_decoder_layer_spec = transformer_decoder_layer_spec() + self.transformer_decoder_layer_spec = transformer_decoder_layer_spec(attn_mask_type=config.attn_mask_type) self.pre_process = pre_process self.post_process = post_process self.add_encoder = True @@ -173,19 +173,33 @@ def __init__( dit_embeddings.ParallelTimestepEmbedding(self.config.hidden_size, self.config.hidden_size, seed=1234), ) + self.fps_embedder = nn.Sequential( + Timesteps(num_channels=256, flip_sin_to_cos=False, downscale_freq_shift=1), + ParallelTimestepEmbedding(256, 256, seed=1234), + ) + if self.pre_process: self.x_embedder = torch.nn.Linear(in_channels * patch_spatial**2, self.config.hidden_size) + if pos_embedder is dit_embeddings.SinCosPosEmb3D: + if self.pre_process: + self.pos_embedder = pos_embedder( + config, + t=max_frames // patch_temporal, + h=max_img_h // patch_spatial, + w=max_img_w // patch_spatial, + ) + else: self.pos_embedder = pos_embedder( config, t=max_frames // patch_temporal, h=max_img_h // patch_spatial, w=max_img_w // patch_spatial, + seed=1234, ) - self.fps_embedder = nn.Sequential( - Timesteps(num_channels=256, flip_sin_to_cos=False, downscale_freq_shift=1), - ParallelTimestepEmbedding(256, 256), - ) + if parallel_state.get_pipeline_model_parallel_world_size() > 1: + for p in self.pos_embedder.parameters(): + setattr(p, "pipeline_parallel", True) if self.post_process: self.final_layer_linear = torch.nn.Linear( 
@@ -194,6 +208,8 @@ def __init__( ) self.affline_norm = RMSNorm(self.config.hidden_size) + if parallel_state.get_pipeline_model_parallel_world_size() > 1: + setattr(self.affline_norm.weight, "pipeline_parallel", True) def forward( self, @@ -223,6 +239,7 @@ def forward( ] * B, dtype=torch.bfloat16, + device=x.device, ), ).view(-1) if self.pre_process: @@ -234,10 +251,16 @@ def forward( else: pos_emb = self.pos_embedder(pos_ids) pos_emb = rearrange(pos_emb, "B S D -> S B D") - x_S_B_D = rearrange(x_B_S_D, "B S D -> S B D") + x_S_B_D = rearrange(x_B_S_D, "B S D -> S B D").contiguous() else: # intermediate stage of pipeline x_S_B_D = None ### should it take encoder_hidden_states + if (not hasattr(self, "pos_embedder")) or isinstance(self.pos_embedder, dit_embeddings.SinCosPosEmb3D): + pos_emb = None + else: + ## if transformer blocks need pos_emb, then pos_embedder should + ## be replicated across pp ranks. + pos_emb = rearrange(self.pos_embedder(pos_ids), "B S D -> S B D").contiguous() timesteps_B_D = self.t_embedder(timesteps.flatten()).to(torch.bfloat16) # (b d_text_embedding) @@ -245,12 +268,17 @@ def forward( fps_B_D = self.fps_embedder(fps) fps_B_D = nn.functional.pad(fps_B_D, (0, self.config.hidden_size - fps_B_D.shape[1])) affline_emb_B_D += fps_B_D + affline_emb_B_D = self.affline_norm(affline_emb_B_D) - crossattn_emb = rearrange(crossattn_emb, 'B S D -> S B D') + crossattn_emb = rearrange(crossattn_emb, 'B S D -> S B D').contiguous() if self.config.sequence_parallel: if self.pre_process: x_S_B_D = tensor_parallel.scatter_to_sequence_parallel_region(x_S_B_D) + if hasattr(self, "pos_embedder") and isinstance( + self.pos_embedder, dit_embeddings.FactorizedLearnable3DEmbedding + ): + pos_emb = tensor_parallel.scatter_to_sequence_parallel_region(pos_emb) crossattn_emb = tensor_parallel.scatter_to_sequence_parallel_region(crossattn_emb) # `scatter_to_sequence_parallel_region` returns a view, which prevents # the original tensor from being garbage collected. Clone to facilitate GC. @@ -309,51 +337,41 @@ def sharded_state_dict( """ sharded_state_dict = super().sharded_state_dict(prefix, sharded_offsets, metadata) - for param_name, param in self.t_embedder.named_parameters(): - weight_key = f'{prefix}t_embedder.{param_name}' - self.tie_embeddings_weights_state_dict(param, sharded_state_dict, weight_key, weight_key) - - for param_name, param in self.affline_norm.named_parameters(): - weight_key = f'{prefix}affline_norm.{param_name}' - self.tie_embeddings_weights_state_dict(param, sharded_state_dict, weight_key, weight_key) - + for module in ['t_embedder']: + for param_name, param in getattr(self, module).named_parameters(): + weight_key = f'{prefix}{module}.{param_name}' + self._set_embedder_weights_replica_id(param, sharded_state_dict, weight_key) return sharded_state_dict - def tie_embeddings_weights_state_dict( - self, - tensor, - sharded_state_dict: ShardedStateDict, - output_layer_weight_key: str, - first_stage_word_emb_key: str, + def _set_embedder_weights_replica_id( + self, tensor: Tensor, sharded_state_dict: ShardedStateDict, embedder_weight_key: str ) -> None: - """Ties the embedding and output weights in a given sharded state dict. + """set replica ids of the weights in t_embedder for sharded state dict. Args: sharded_state_dict (ShardedStateDict): state dict with the weight to tie - output_layer_weight_key (str): key of the output layer weight in the state dict. + weight_key (str): key of the weight in the state dict. 
This entry will be replaced with a tied version - first_stage_word_emb_key (str): this must be the same as the - ShardedTensor.key of the first stage word embeddings. Returns: None, acts in-place """ - if self.pre_process and parallel_state.get_tensor_model_parallel_rank() == 0: - # Output layer is equivalent to the embedding already - return - - # Replace the default output layer with a one sharing the weights with the embedding - del sharded_state_dict[output_layer_weight_key] - last_stage_word_emb_replica_id = ( - 0, # copy of first stage embedding - parallel_state.get_tensor_model_parallel_rank() - + parallel_state.get_pipeline_model_parallel_rank() - * parallel_state.get_pipeline_model_parallel_world_size(), + tp_rank = parallel_state.get_tensor_model_parallel_rank() + vpp_rank = parallel_state.get_virtual_pipeline_model_parallel_rank() + vpp_rank = vpp_rank if vpp_rank else 0 + vpp_world = parallel_state.get_virtual_pipeline_model_parallel_world_size() + vpp_world = vpp_world if vpp_world else 1 + pp_rank = parallel_state.get_pipeline_model_parallel_rank() + if embedder_weight_key in sharded_state_dict: + del sharded_state_dict[embedder_weight_key] + replica_id = ( + tp_rank, + (vpp_rank + pp_rank * vpp_world), parallel_state.get_data_parallel_rank(with_context_parallel=True), ) - sharded_state_dict[output_layer_weight_key] = make_sharded_tensor_for_checkpoint( + sharded_state_dict[embedder_weight_key] = make_sharded_tensor_for_checkpoint( tensor=tensor, - key=first_stage_word_emb_key, - replica_id=last_stage_word_emb_replica_id, + key=embedder_weight_key, + replica_id=replica_id, allow_shape_mismatch=False, ) diff --git a/nemo/collections/diffusion/models/dit_llama/dit_llama_layer_spec.py b/nemo/collections/diffusion/models/dit_llama/dit_llama_layer_spec.py index 80bed5878e1b..305db1f2c993 100644 --- a/nemo/collections/diffusion/models/dit_llama/dit_llama_layer_spec.py +++ b/nemo/collections/diffusion/models/dit_llama/dit_llama_layer_spec.py @@ -13,7 +13,7 @@ # limitations under the License. 
import copy -from typing import Literal +from typing import Literal, Optional from megatron.core.transformer.attention import ( CrossAttention, @@ -22,13 +22,18 @@ SelfAttentionSubmodules, ) from megatron.core.transformer.custom_layers.transformer_engine import ( + TEColumnParallelGroupedLinear, TEColumnParallelLinear, TEDotProductAttention, + TENorm, + TERowParallelGroupedLinear, TERowParallelLinear, ) from megatron.core.transformer.enums import AttnMaskType from megatron.core.transformer.identity_op import IdentityOp from megatron.core.transformer.mlp import MLP, MLPSubmodules +from megatron.core.transformer.moe.moe_layer import MoELayer, MoESubmodules +from megatron.core.transformer.moe.shared_experts import SharedExpertMLP from megatron.core.transformer.spec_utils import ModuleSpec, build_module from megatron.core.transformer.transformer_block import TransformerConfig from megatron.core.transformer.transformer_config import TransformerConfig @@ -78,7 +83,7 @@ def _replace_no_cp_submodules(submodules): layer_number=layer_number, ) - self.adaLN = AdaLN(config=self.config, n_adaln_chunks=6) # , norm=TENorm) + self.adaLN = AdaLN(config=self.config, n_adaln_chunks=6, norm=TENorm) def forward( self, @@ -138,8 +143,57 @@ def forward( return output, context -def get_dit_llama_spec() -> ModuleSpec: - params = {"attn_mask_type": AttnMaskType.padding} +def _get_mlp_module_spec( + use_te: Optional[bool] = True, + num_experts: Optional[int] = None, + moe_grouped_gemm: Optional[bool] = False, + fp8: Optional[str] = None, +) -> ModuleSpec: + """Helper function to get module spec for MLP/MoE""" + if num_experts is None: + # Dense MLP w/ or w/o TE modules. + return ModuleSpec( + module=MLP, + submodules=MLPSubmodules( + linear_fc1=TEColumnParallelLinear, + linear_fc2=TERowParallelLinear, + ), + ) + else: + # Mixture of experts with modules in megatron core. 
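+        # Grouped GEMM runs the per-expert GEMMs as a single batched kernel; the fp8 path below
+        # falls back to per-expert TE column/row-parallel linears instead.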
+ if use_te and moe_grouped_gemm: + linear_fc1 = TEColumnParallelGroupedLinear + linear_fc2 = TERowParallelGroupedLinear + elif use_te and fp8: + linear_fc1 = TEColumnParallelLinear + linear_fc2 = TERowParallelLinear + else: + raise ValueError("Invalid combination of use_te and moe_grouped_gemm") + + use_te_grouped_gemm = use_te and TEColumnParallelGroupedLinear is not None + + return ModuleSpec( + module=MoELayer, + submodules=MoESubmodules( + experts=( + MLPSubmodules(linear_fc1=linear_fc1, linear_fc2=linear_fc2) + if not moe_grouped_gemm or use_te_grouped_gemm + else None + ), + shared_experts=ModuleSpec( + module=SharedExpertMLP, + params={"gate": False}, + submodules=MLPSubmodules( + linear_fc1=TEColumnParallelLinear, + linear_fc2=TERowParallelLinear, + ), + ), + ), + ) + + +def get_dit_llama_spec(num_experts=None, attn_mask_type=AttnMaskType.padding) -> ModuleSpec: + params = {"attn_mask_type": attn_mask_type} return ModuleSpec( module=MoviegGenLayer, submodules=TransformerLayerSubmodules( @@ -162,12 +216,6 @@ def get_dit_llama_spec() -> ModuleSpec: linear_proj=TERowParallelLinear, ), ), - mlp=ModuleSpec( - module=MLP, - submodules=MLPSubmodules( - linear_fc1=TEColumnParallelLinear, - linear_fc2=TERowParallelLinear, - ), - ), + mlp=_get_mlp_module_spec(use_te=True, num_experts=num_experts, moe_grouped_gemm=True, fp8=None), ), ) diff --git a/nemo/collections/diffusion/models/dit_llama/dit_llama_model.py b/nemo/collections/diffusion/models/dit_llama/dit_llama_model.py index bfa79e366cac..8ec0c7097c63 100644 --- a/nemo/collections/diffusion/models/dit_llama/dit_llama_model.py +++ b/nemo/collections/diffusion/models/dit_llama/dit_llama_model.py @@ -12,7 +12,7 @@ # See the License for the specific language governing permissions and # limitations under the License. 
- +from functools import partial from typing import Literal from megatron.core.transformer.transformer_config import TransformerConfig @@ -54,7 +54,9 @@ def __init__( patch_temporal=patch_temporal, in_channels=in_channels, out_channels=out_channels, - transformer_decoder_layer_spec=get_dit_llama_spec, + transformer_decoder_layer_spec=partial( + get_dit_llama_spec, num_experts=config.num_moe_experts, attn_mask_type=config.attn_mask_type + ), pos_embedder=dit_embeddings.FactorizedLearnable3DEmbedding, **kwargs, ) diff --git a/nemo/collections/diffusion/models/model.py b/nemo/collections/diffusion/models/model.py index 8cc6be860585..9ee0ab441700 100644 --- a/nemo/collections/diffusion/models/model.py +++ b/nemo/collections/diffusion/models/model.py @@ -14,7 +14,7 @@ import importlib import warnings -from dataclasses import dataclass, field +from dataclasses import dataclass from typing import Any, Callable, Dict, Optional, Tuple import numpy as np @@ -24,6 +24,7 @@ from einops import rearrange from megatron.core import parallel_state from megatron.core.packed_seq_params import PackedSeqParams +from megatron.core.transformer.enums import AttnMaskType from megatron.core.transformer.transformer_config import TransformerConfig from torch import nn from typing_extensions import override @@ -39,10 +40,12 @@ def dit_forward_step(model, batch) -> torch.Tensor: + """Forward pass of DiT.""" return model(**batch) def dit_data_step(module, dataloader_iter): + """DiT data batch preparation.""" batch = next(dataloader_iter)[0] batch = get_batch_on_this_cp_rank(batch) batch = {k: v.to(device='cuda', non_blocking=True) if torch.is_tensor(v) else v for k, v in batch.items()} @@ -58,12 +61,12 @@ def dit_data_step(module, dataloader_iter): 'self_attention': PackedSeqParams( cu_seqlens_q=cu_seqlens, cu_seqlens_kv=cu_seqlens, - qkv_format='sbhd', + qkv_format=module.qkv_format, ), 'cross_attention': PackedSeqParams( cu_seqlens_q=cu_seqlens, cu_seqlens_kv=cu_seqlens_kv, - qkv_format='sbhd', + qkv_format=module.qkv_format, ), } @@ -77,9 +80,7 @@ def get_batch_on_this_cp_rank(data: Dict): cp_size = mpu.get_context_parallel_world_size() cp_rank = mpu.get_context_parallel_rank() - t = 16 if cp_size > 1: - assert t % cp_size == 0, "t must divisibly by cp_size" num_valid_tokens_in_ub = None if 'loss_mask' in data and data['loss_mask'] is not None: num_valid_tokens_in_ub = data['loss_mask'].sum() @@ -88,9 +89,13 @@ def get_batch_on_this_cp_rank(data: Dict): if (value is not None) and (key in ['video', 'video_latent', 'noise_latent', 'pos_ids']): if len(value.shape) > 5: value = value.squeeze(0) - B, C, T, H, W = value.shape + if len(value.shape) == 5: + B, C, T, H, W = value.shape + data[key] = value.view(B, C, cp_size, T // cp_size, H, W)[:, :, cp_rank, ...].contiguous() + else: + B, S, D = value.shape + data[key] = value.view(B, cp_size, S // cp_size, D)[:, cp_rank, ...].contiguous() # TODO: sequence packing - data[key] = value.view(B, C, cp_size, T // cp_size, H, W)[:, :, cp_rank, ...].contiguous() loss_mask = data["loss_mask"] data["loss_mask"] = loss_mask.view(loss_mask.shape[0], cp_size, loss_mask.shape[1] // cp_size)[ :, cp_rank, ... 
@@ -142,8 +147,16 @@ class DiTConfig(TransformerConfig, io.IOMixin): data_step_fn = dit_data_step forward_step_fn = dit_forward_step + replicated_t_embedder = True + + seq_length: int = 2048 + + qkv_format: str = 'sbhd' + attn_mask_type: AttnMaskType = AttnMaskType.no_mask + @override def configure_model(self, tokenizer=None) -> DiTCrossAttentionModel: + """Configure DiT Model from MCore.""" vp_size = self.virtual_pipeline_model_parallel_size if vp_size: p_size = self.pipeline_model_parallel_size @@ -168,11 +181,14 @@ def configure_model(self, tokenizer=None) -> DiTCrossAttentionModel: ) def configure_vae(self): + """Dynamically import video tokenizer.""" return dynamic_import(self.vae_module)(self.vae_path) @dataclass class DiTBConfig(DiTConfig): + """DiT-B""" + num_layers: int = 12 hidden_size: int = 768 num_attention_heads: int = 12 @@ -180,6 +196,8 @@ class DiTBConfig(DiTConfig): @dataclass class DiTLConfig(DiTConfig): + """DiT-L""" + num_layers: int = 24 hidden_size: int = 1024 num_attention_heads: int = 16 @@ -187,6 +205,8 @@ class DiTLConfig(DiTConfig): @dataclass class DiTXLConfig(DiTConfig): + """DiT-XL""" + num_layers: int = 28 hidden_size: int = 1152 num_attention_heads: int = 16 @@ -194,6 +214,8 @@ class DiTXLConfig(DiTConfig): @dataclass class DiT7BConfig(DiTConfig): + """DiT-7B""" + num_layers: int = 32 hidden_size: int = 3072 num_attention_heads: int = 24 @@ -201,6 +223,8 @@ class DiT7BConfig(DiTConfig): @dataclass class DiTLlama30BConfig(DiTConfig): + """MovieGen 30B""" + num_layers: int = 48 hidden_size: int = 6144 ffn_hidden_size: int = 16384 @@ -228,13 +252,42 @@ class DiTLlama30BConfig(DiTConfig): @dataclass class DiTLlama5BConfig(DiTLlama30BConfig): + """MovieGen 5B""" + num_layers: int = 32 hidden_size: int = 3072 ffn_hidden_size: int = 8192 num_attention_heads: int = 24 +@dataclass +class DiTLlama1BConfig(DiTLlama30BConfig): + """MovieGen 1B""" + + num_layers: int = 16 + hidden_size: int = 2048 + ffn_hidden_size: int = 8192 + num_attention_heads: int = 32 + + +@dataclass +class ECDiTLlama1BConfig(DiTLlama1BConfig): + "EC-DiT 1B" + moe_router_load_balancing_type: str = 'expert_choice' + moe_token_dispatcher_type: str = 'alltoall' + moe_grouped_gemm: bool = True + moe_expert_capacity_factor: float = 8 + moe_pad_expert_input_to_capacity: bool = True + moe_router_topk: int = 1 + num_moe_experts: int = 64 + ffn_hidden_size: int = 1024 + + class DiTModel(GPTModel): + """ + Diffusion Transformer Model + """ + def __init__( self, config: Optional[DiTConfig] = None, @@ -256,6 +309,9 @@ def __init__( self.vae = None + def load_state_dict(self, state_dict, strict=False): + self.module.load_state_dict(state_dict, strict=False) + def data_step(self, dataloader_iter) -> Dict[str, Any]: return self.config.data_step_fn(dataloader_iter) @@ -284,10 +340,12 @@ def on_validation_start(self): self.vae.to('cuda') def on_validation_end(self): + """Move video tokenizer to CPU after validation.""" if self.vae is not None: self.vae.to('cpu') def validation_step(self, batch, batch_idx=None) -> torch.Tensor: + """Generated validation video sample and logs to wandb.""" # In mcore the loss-function is part of the forward-pass (when labels are provided) state_shape = batch['video'].shape sample = self.diffusion_pipeline.generate_samples_from_batch( @@ -304,7 +362,7 @@ def validation_step(self, batch, batch_idx=None) -> torch.Tensor: seq_len_q = batch['seq_len_q'][0] sample = rearrange( - sample[:, :seq_len_q], + sample[0, None, :seq_len_q], 'B (T H W) (ph pw pt C) -> B C (T pt) (H ph) (W pw)', 
ph=self.config.patch_spatial, pw=self.config.patch_spatial, @@ -318,13 +376,7 @@ def validation_step(self, batch, batch_idx=None) -> torch.Tensor: video = (video * 255).to(torch.uint8).cpu().numpy().astype(np.uint8) - T = video.shape[2] - if T == 1: - image = rearrange(video, 'b c t h w -> (b t h) w c') - result = image - else: - # result = wandb.Video(video, fps=float(batch['fps'])) # (batch, time, channel, height width) - result = video + result = rearrange(video, 'b c t h w -> (b t) c h w') # wandb is on the last rank for megatron, first rank for nemo wandb_rank = 0 @@ -340,11 +392,12 @@ def validation_step(self, batch, batch_idx=None) -> torch.Tensor: if gather_list is not None: videos = [] for video in gather_list: - if len(video.shape) == 3: - videos.append(wandb.Image(video)) - else: - videos.append(wandb.Video(video, fps=30)) - wandb.log({'prediction': videos}, step=self.global_step) + try: + videos.append(wandb.Video(video, fps=24, format='mp4')) + except Exception as e: + warnings.warn(f'Error saving video as mp4: {e}') + videos.append(wandb.Video(video, fps=24)) + wandb.log({'prediction': videos}) return None @@ -375,6 +428,10 @@ def on_validation_model_zero_grad(self) -> None: class DummyLossReduction(MegatronLossReduction): + """ + Diffusion Loss Reduction + """ + def __init__(self, validation_step: bool = False, val_drop_last: bool = True) -> None: super().__init__() self.validation_step = validation_step diff --git a/nemo/collections/diffusion/sampler/edm/edm_pipeline.py b/nemo/collections/diffusion/sampler/edm/edm_pipeline.py index 6e1be1f6f2a6..16d3177088a9 100644 --- a/nemo/collections/diffusion/sampler/edm/edm_pipeline.py +++ b/nemo/collections/diffusion/sampler/edm/edm_pipeline.py @@ -427,8 +427,13 @@ def get_data_and_condition(self, data_batch: dict[str, Tensor], dropout_rate=0.2 latent_state = raw_state # Condition - data_batch['crossattn_emb'] = self.random_dropout_input( + condition = {} # Create a new dictionary for condition + # Copy all keys from data_batch except 'video' + for key, value in data_batch.items(): + if key not in ['video', 't5_text_embeddings']: + condition[key] = value + condition['crossattn_emb'] = self.random_dropout_input( data_batch['t5_text_embeddings'], dropout_rate=dropout_rate ) - return raw_state, latent_state, data_batch + return raw_state, latent_state, condition diff --git a/nemo/collections/diffusion/train.py b/nemo/collections/diffusion/train.py index 43a0a5dcb536..5428e0eeefa2 100644 --- a/nemo/collections/diffusion/train.py +++ b/nemo/collections/diffusion/train.py @@ -1,4 +1,4 @@ -# Copyright (c) 2020, NVIDIA CORPORATION. All rights reserved. +# Copyright (c) 2024, NVIDIA CORPORATION. All rights reserved. # # Licensed under the Apache License, Version 2.0 (the "License"); # you may not use this file except in compliance with the License. 
@@ -19,29 +19,38 @@ import torch from megatron.core.distributed import DistributedDataParallelConfig from megatron.core.optimizer import OptimizerConfig +from megatron.core.transformer.enums import AttnMaskType from pytorch_lightning.loggers import WandbLogger from nemo import lightning as nl from nemo.collections import llm from nemo.collections.diffusion.data.diffusion_energon_datamodule import DiffusionDataModule +from nemo.collections.diffusion.data.diffusion_fake_datamodule import VideoLatentFakeDataModule from nemo.collections.diffusion.data.diffusion_taskencoder import BasicDiffusionTaskEncoder from nemo.collections.diffusion.models.model import ( DiT7BConfig, DiTConfig, DiTLConfig, + DiTLlama1BConfig, DiTLlama5BConfig, DiTLlama30BConfig, DiTModel, DiTXLConfig, + ECDiTLlama1BConfig, ) +from nemo.collections.multimodal.data.energon.base import SimpleMultiModalDataModule from nemo.lightning.pytorch.callbacks import ModelCheckpoint, PreemptionCallback +from nemo.lightning.pytorch.callbacks.megatron_comm_overlap import MegatronCommOverlapCallback from nemo.lightning.pytorch.callbacks.model_transform import ModelTransform +from nemo.lightning.pytorch.callbacks.nsys import NsysCallback from nemo.lightning.pytorch.strategies.utils import RestoreConfig +from nemo.utils.exp_manager import TimingCallback @run.cli.factory @run.autoconvert def multimodal_datamodule() -> pl.LightningDataModule: + """Multimodal Datamodule Initialization""" data_module = DiffusionDataModule( seq_length=2048, task_encoder=run.Config(BasicDiffusionTaskEncoder, seq_length=2048), @@ -51,9 +60,39 @@ def multimodal_datamodule() -> pl.LightningDataModule: return data_module +@run.cli.factory +@run.autoconvert +def simple_datamodule() -> pl.LightningDataModule: + """Simple Datamodule Initialization""" + data_module = SimpleMultiModalDataModule( + seq_length=2048, + micro_batch_size=1, + global_batch_size=32, + num_workers=16, + tokenizer=None, + image_processor=None, + task_encoder=run.Config(BasicDiffusionTaskEncoder, seq_length=2048), + ) + return data_module + + +@run.cli.factory +@run.autoconvert +def multimodal_fake_datamodule() -> pl.LightningDataModule: + """Multimodal Mock Datamodule Initialization""" + data_module = VideoLatentFakeDataModule( + seq_length=None, # Set None to dectect the sequence length automatically. 
+ task_encoder=run.Config(BasicDiffusionTaskEncoder, seq_length=2048), + micro_batch_size=1, + global_batch_size=32, + ) + return data_module + + @run.cli.factory @run.autoconvert def peft(args) -> ModelTransform: + """Parameter Efficient Fine Tuning""" return llm.peft.LoRA( target_modules=['linear_qkv', 'linear_proj'], # , 'linear_fc1', 'linear_fc2'], dim=args.lora_dim, @@ -62,6 +101,7 @@ def peft(args) -> ModelTransform: @run.cli.factory(target=llm.train) def pretrain() -> run.Partial: + """Base Pretraining Config""" return run.Partial( llm.train, model=run.Config( @@ -85,6 +125,8 @@ def pretrain() -> run.Partial: DistributedDataParallelConfig, check_for_nan_in_grad=True, grad_reduce_in_fp32=True, + overlap_grad_reduce=True, + overlap_param_gather=True, ), ), plugins=nl.MegatronMixedPrecision(precision="bf16-mixed"), @@ -96,12 +138,18 @@ def pretrain() -> run.Partial: callbacks=[ run.Config( ModelCheckpoint, - monitor='reduced_train_loss', - filename='{epoch}-{step}', + monitor='global_step', + filename='{global_step}', every_n_train_steps=1000, - save_top_k=-1, + save_top_k=3, + mode='max', ), run.Config(PreemptionCallback), + run.Config(TimingCallback), + run.Config( + MegatronCommOverlapCallback, + tp_comm_overlap=False, + ), ], ), log=nl.NeMoLogger(wandb=(WandbLogger() if "WANDB_API_KEY" in os.environ else None)), @@ -129,6 +177,7 @@ def pretrain() -> run.Partial: @run.cli.factory(target=llm.train) def pretrain_xl() -> run.Partial: + """DiT-XL Pretraining Recipe""" recipe = pretrain() recipe.model.config = run.Config(DiTXLConfig) return recipe @@ -136,13 +185,89 @@ def pretrain_xl() -> run.Partial: @run.cli.factory(target=llm.train) def pretrain_l() -> run.Partial: + """DiT-L Pretraining Recipe""" recipe = pretrain() recipe.model.config = run.Config(DiTLConfig) return recipe +@run.cli.factory(target=llm.train) +def train_mock() -> run.Partial: + """DiT Mock Pretraining Recipe""" + recipe = pretrain() + recipe.model.config = run.Config(DiTLlama5BConfig, max_frames=1) + recipe.data = multimodal_fake_datamodule() + recipe.model.config.num_layers = 16 + recipe.data.seq_length = 73728 + recipe.data.task_encoder.seq_length = 73728 + recipe.trainer.strategy.tensor_model_parallel_size = 4 + recipe.trainer.strategy.sequence_parallel = True + recipe.trainer.strategy.context_parallel_size = 2 + recipe.data.micro_batch_size = 1 + recipe.data.global_batch_size = 1 + recipe.trainer.limit_val_batches = 0 + recipe.trainer.val_check_interval = 1.0 + recipe.data.model_config = recipe.model.config + recipe.log.log_dir = 'nemo_experiments/train_mock' + + recipe.trainer.strategy.ddp.with_megatron_fsdp_code_path = True + recipe.trainer.strategy.ddp.data_parallel_sharding_strategy = 'MODEL_AND_OPTIMIZER_STATES' + recipe.trainer.strategy.ddp.overlap_param_gather = True + recipe.trainer.strategy.ddp.overlap_grad_reduce = True + recipe.model.config.use_cpu_initialization = True + + return recipe + + +@run.cli.factory(target=llm.train) +def mock_ditllama5b_8k() -> run.Partial: + recipe = pretrain() + recipe.model.config = run.Config(DiTLlama5BConfig, max_frames=1) + recipe.data = multimodal_fake_datamodule() + recipe.data.seq_length = recipe.data.task_encoder.seq_length = 8192 + recipe.trainer.strategy.tensor_model_parallel_size = 2 + recipe.trainer.strategy.sequence_parallel = True + recipe.trainer.strategy.context_parallel_size = 1 + recipe.data.micro_batch_size = 1 + recipe.data.global_batch_size = 32 + recipe.trainer.limit_val_batches = 0 + recipe.trainer.val_check_interval = 1.0 + recipe.data.model_config 
= recipe.model.config + recipe.log.log_dir = 'nemo_experiments/mock_ditllama5b_8k' + recipe.model.config.attn_mask_type = AttnMaskType.no_mask + recipe.trainer.strategy.ddp.with_megatron_fsdp_code_path = True + recipe.trainer.strategy.ddp.data_parallel_sharding_strategy = 'MODEL_AND_OPTIMIZER_STATES' + recipe.trainer.strategy.ddp.overlap_param_gather = True + recipe.trainer.strategy.ddp.overlap_grad_reduce = True + recipe.model.config.use_cpu_initialization = True + recipe.trainer.max_steps = 15 + recipe.trainer.callbacks.pop(0) + recipe.trainer.enable_checkpointing = False + recipe.trainer.callbacks.append( + run.Config( + NsysCallback, + start_step=10, + end_step=11, + ) + ) + recipe.resume = None + return recipe + + +@run.cli.factory(target=llm.train) +def mock_dit7b_8k() -> run.Partial: + recipe = mock_ditllama5b_8k() + recipe.model.config = run.Config(DiT7BConfig, max_frames=1) + recipe.data.model_config = recipe.model.config + recipe.model.config.attn_mask_type = AttnMaskType.no_mask + recipe.model.config.use_cpu_initialization = True + recipe.log.log_dir = 'nemo_experiments/mock_dit7b_8k' + return recipe + + @run.cli.factory(target=llm.train) def pretrain_7b() -> run.Partial: + """DiT-7B Pretraining Recipe""" recipe = pretrain() recipe.model.config = run.Config(DiT7BConfig) recipe.data.global_batch_size = 4608 @@ -161,8 +286,59 @@ def pretrain_7b() -> run.Partial: return recipe +@run.cli.factory(target=llm.train) +def pretrain_7b_pack() -> run.Partial: + """DiT-7B Pretraining Recipe with Packing""" + recipe = pretrain_7b() + recipe.data.global_batch_size = 4608 // 9 + recipe.data.micro_batch_size = 1 + recipe.data.num_workers = 15 + recipe.data.use_train_split_for_val = True + recipe.data.seq_length = 256 * 9 + recipe.data.packing_buffer_size = 1000 + recipe.data.task_encoder.seq_length = None + recipe.data.task_encoder.max_seq_length = recipe.data.seq_length + recipe.model.config.qkv_format = 'thd' + return recipe + + +@run.cli.factory(target=llm.train) +def pretrain_7b_256p_joint() -> run.Partial: + """DiT-7B Pretraining Recipe 256p Stage 1""" + recipe = pretrain_7b() + recipe.data.global_batch_size = 256 # 768 + recipe.data.micro_batch_size = 1 + recipe.data.seq_length = 8192 + recipe.data.task_encoder.seq_length = 8192 + recipe.model.config.seq_length = 8192 + + recipe.optim.config.lr = 6e-5 + recipe.trainer.strategy.tensor_model_parallel_size = 2 + recipe.trainer.strategy.sequence_parallel = True + recipe.trainer.strategy.ddp.overlap_grad_reduce = True + # recipe.resume.restore_config = run.Config(RestoreConfig, path='', load_optim_state=True) + recipe.log.log_dir = 'nemo_experiments/pretrain_7b_256p_joint' + return recipe + + +@run.cli.factory(target=llm.train) +def pretrain_7b_256p_joint_pack() -> run.Partial: + """DiT-7B Pretraining Recipe 256p Stage 1 with Packing""" + recipe = pretrain_7b_256p_joint() + recipe.data.global_batch_size = 128 + recipe.data.micro_batch_size = 1 + recipe.data.num_workers = 10 + recipe.data.seq_length = recipe.model.config.seq_length = recipe.data.task_encoder.max_seq_length = 10240 + recipe.data.task_encoder.seq_length = None + recipe.data.packing_buffer_size = 1000 + recipe.data.virtual_epoch_length = 0 + recipe.model.config.qkv_format = 'thd' + return recipe + + @run.cli.factory(target=llm.train) def pretrain_ditllama5b() -> run.Partial: + """MovieGen 5B Training""" recipe = pretrain_7b() recipe.data.micro_batch_size = 12 recipe.model.config = run.Config(DiTLlama5BConfig) @@ -172,30 +348,200 @@ def pretrain_ditllama5b() -> run.Partial: 
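# A minimal sketch of the derive-and-override pattern the factories in this file follow;
# `my_dit7b_16k` and its particular settings are hypothetical, while `pretrain_7b()`,
# `run.cli.factory`, and `llm.train` are used exactly as elsewhere in this module.
@run.cli.factory(target=llm.train)
def my_dit7b_16k() -> run.Partial:
    """Hypothetical DiT-7B variant with longer sequences and more tensor parallelism."""
    recipe = pretrain_7b()  # start from an existing factory and override only what differs
    recipe.data.seq_length = 16384
    recipe.data.task_encoder.seq_length = 16384  # keep data and task-encoder lengths in sync
    recipe.model.config.seq_length = 16384       # the model config must agree as well
    recipe.data.micro_batch_size = 1
    recipe.trainer.strategy.tensor_model_parallel_size = 4
    recipe.trainer.strategy.sequence_parallel = True
    recipe.log.log_dir = 'nemo_experiments/my_dit7b_16k'
    return recipe
# Such a factory is launched the same way as the others, e.g.
# `run.cli.main(llm.train, default_factory=my_dit7b_16k)`.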
@run.cli.factory(target=llm.train) def pretrain_ditllama30b() -> run.Partial: + """MovieGen 30B Stage 1 Training""" recipe = pretrain_ditllama5b() recipe.model.config = run.Config(DiTLlama30BConfig) recipe.data.global_batch_size = 9216 recipe.data.micro_batch_size = 6 - recipe.log.log_dir = 'nemo_experiments/ditllama30b' + recipe.data.task_encoder.aethetic_score = 4.0 + recipe.data.seq_length = 256 + recipe.data.task_encoder.seq_length = 256 + recipe.data.virtual_epoch_length = 0 + recipe.log.log_dir = 'nemo_experiments/ditllama30b_stage1_mock' + recipe.trainer.strategy.ddp.with_megatron_fsdp_code_path = True + recipe.trainer.strategy.ddp.data_parallel_sharding_strategy = 'MODEL_AND_OPTIMIZER_STATES' + recipe.trainer.strategy.ddp.overlap_param_gather = True + recipe.trainer.strategy.ddp.overlap_grad_reduce = True + recipe.model.config.use_cpu_initialization = True + return recipe + + +@run.cli.factory(target=llm.train) +def pretrain_ditllama30b_stage2_mock() -> run.Partial: + """MovieGen 30B Stage 2 Training""" + recipe = pretrain_ditllama5b() + recipe.model.config = run.Config(DiTLlama30BConfig) + recipe.data = multimodal_fake_datamodule() + recipe.data.model_config = recipe.model.config + recipe.data.seq_length = 8192 + recipe.data.task_encoder.seq_length = 8192 + recipe.data.global_batch_size = 256 + recipe.data.micro_batch_size = 1 + recipe.trainer.strategy.tensor_model_parallel_size = 2 + recipe.trainer.strategy.context_parallel_size = 4 + recipe.trainer.strategy.sequence_parallel = True + recipe.trainer.limit_val_batches = 0 + recipe.trainer.val_check_interval = 1.0 + recipe.data.model_config = recipe.model.config + recipe.log.log_dir = 'nemo_experiments/ditllama30b_stage2_mock' + recipe.trainer.strategy.ddp.with_megatron_fsdp_code_path = True + recipe.trainer.strategy.ddp.data_parallel_sharding_strategy = 'MODEL_AND_OPTIMIZER_STATES' + recipe.trainer.strategy.ddp.overlap_param_gather = True + recipe.trainer.strategy.ddp.overlap_grad_reduce = True + recipe.model.config.use_cpu_initialization = True + return recipe + + +@run.cli.factory(target=llm.train) +def pretrain_ditllama30b_stage3_mock() -> run.Partial: + """MovieGen 30B Stage 3 Training""" + recipe = pretrain_ditllama5b() + recipe.model.config = run.Config(DiTLlama30BConfig) + recipe.data = multimodal_fake_datamodule() + recipe.data.model_config = recipe.model.config + recipe.data.seq_length = 73728 + recipe.data.task_encoder.seq_length = 73728 + recipe.data.global_batch_size = 256 + recipe.data.micro_batch_size = 1 + recipe.trainer.strategy.tensor_model_parallel_size = 2 + recipe.trainer.strategy.context_parallel_size = 8 + recipe.trainer.strategy.sequence_parallel = True + recipe.trainer.limit_val_batches = 0 + recipe.trainer.val_check_interval = 1.0 + recipe.data.model_config = recipe.model.config + recipe.log.log_dir = 'nemo_experiments/ditllama30b_stage3_mock' + recipe.trainer.strategy.ddp.with_megatron_fsdp_code_path = True + recipe.trainer.strategy.ddp.data_parallel_sharding_strategy = 'MODEL_AND_OPTIMIZER_STATES' + recipe.trainer.strategy.ddp.overlap_param_gather = True + recipe.trainer.strategy.ddp.overlap_grad_reduce = True + recipe.model.config.use_cpu_initialization = True + return recipe + + +@run.cli.factory(target=llm.train) +def pretrain_ditllama5b_stage3_mock_with_pp() -> run.Partial: + """MovieGen 30B Stage 3 Training""" + recipe = pretrain_ditllama5b() + recipe.data = multimodal_fake_datamodule() + recipe.data.model_config = recipe.model.config + recipe.data.seq_length = 8192 + 
recipe.data.task_encoder.seq_length = 8192 + recipe.data.global_batch_size = 1 + recipe.data.micro_batch_size = 1 + recipe.trainer.strategy.tensor_model_parallel_size = 2 + recipe.trainer.strategy.pipeline_model_parallel_size = 2 + recipe.trainer.strategy.context_parallel_size = 2 + recipe.trainer.strategy.sequence_parallel = True + recipe.trainer.limit_val_batches = 0 + recipe.trainer.val_check_interval = 1.0 + recipe.data.model_config = recipe.model.config + recipe.log.log_dir = 'nemo_experiments/ditllama30b_stage5_mock_with_pp' + return recipe + + +@run.cli.factory(target=llm.train) +def pretrain_ditllama30b_stage3_mock_with_pp() -> run.Partial: + """MovieGen 30B Stage 3 Training with Pipeline Parallelism""" + recipe = pretrain_ditllama5b() + recipe.model.config = run.Config(DiTLlama30BConfig) + recipe.data = multimodal_fake_datamodule() + recipe.data.model_config = recipe.model.config + recipe.data.seq_length = 73728 + recipe.data.task_encoder.seq_length = 73728 + recipe.data.global_batch_size = 256 + recipe.data.micro_batch_size = 1 + recipe.trainer.strategy.tensor_model_parallel_size = 4 + recipe.trainer.strategy.pipeline_model_parallel_size = 4 + recipe.trainer.strategy.context_parallel_size = 8 + recipe.trainer.strategy.sequence_parallel = True + recipe.trainer.limit_val_batches = 0 + recipe.trainer.val_check_interval = 1.0 + recipe.data.model_config = recipe.model.config + recipe.log.log_dir = 'nemo_experiments/ditllama30b_stage3_mock_with_pp' + return recipe + + +@run.cli.factory(target=llm.train) +def pretrain_ditllama1b() -> run.Partial: + """MovieGen 1B Stage 1 Training""" + recipe = pretrain_ditllama5b() + recipe.model.config = run.Config(DiTLlama1BConfig) + recipe.data.task_encoder.aethetic_score = 4.0 + recipe.data.seq_length = 256 + recipe.data.task_encoder.seq_length = 256 + recipe.model.config.seq_length = 256 + recipe.data.global_batch_size = 1536 + recipe.data.micro_batch_size = 96 + recipe.trainer.strategy.ddp.overlap_grad_reduce = True + recipe.log.log_dir = 'nemo_experiments/ditllama1b' + recipe.trainer.val_check_interval = 3000 + recipe.trainer.callbacks[0].every_n_train_steps = 3000 + recipe.trainer.callbacks[0].monitor = 'global_step' + recipe.trainer.callbacks[0].save_top_k = 3 + recipe.trainer.callbacks[0].mode = 'max' + return recipe + + +@run.cli.factory(target=llm.train) +def pretrain_ditllama3b() -> run.Partial: + """MovieGen 3B Stage 1 Training""" + recipe = pretrain_ditllama1b() + recipe.data.micro_batch_size = 48 + recipe.model.config = run.Config( + DiTLlama1BConfig, + hidden_size=3072, + num_layers=28, + num_attention_heads=24, + ffn_hidden_size=8192, + ) + recipe.log.log_dir = 'nemo_experiments/ditllama3b' + + return recipe + + +@run.cli.factory(target=llm.train) +def pretrain_ecditllama1b() -> run.Partial: + """EC-DiT 1B Training""" + recipe = pretrain_ditllama1b() + recipe.data.task_encoder.aethetic_score = 5.0 + recipe.data.micro_batch_size = 72 + recipe.data.global_batch_size = 2304 + recipe.model.config = run.Config(ECDiTLlama1BConfig) + recipe.log.log_dir = 'nemo_experiments/ecditllama1b' + recipe.trainer.val_check_interval = 3000 + + recipe.trainer.strategy.ddp.with_megatron_fsdp_code_path = True + recipe.trainer.strategy.ddp.data_parallel_sharding_strategy = 'MODEL_AND_OPTIMIZER_STATES' + recipe.trainer.strategy.ddp.overlap_param_gather = True + recipe.trainer.strategy.ddp.overlap_grad_reduce = True + recipe.model.config.use_cpu_initialization = True + return recipe @run.cli.factory(target=llm.train) def dreambooth() -> run.Partial: + 
"""Dreambooth Fine Tuning""" recipe = pretrain() recipe.optim.config.lr = 1e-6 recipe.data = multimodal_datamodule() recipe.model.config = run.Config(DiTConfig) - recipe.trainer.max_steps = 1000 recipe.trainer.strategy.tensor_model_parallel_size = 8 recipe.trainer.strategy.sequence_parallel = True - recipe.resume.restore_config = run.Config(RestoreConfig) recipe.resume.resume_if_exists = False - return recipe if __name__ == "__main__": + OOM_DEBUG = False + if OOM_DEBUG: + torch.cuda.memory._record_memory_history( + True, + # Keep 100,000 alloc/free events from before the snapshot + trace_alloc_max_entries=100000, + # Record stack information for the trace events + trace_alloc_record_context=True, + ) run.cli.main(llm.train, default_factory=dreambooth) diff --git a/nemo/collections/llm/api.py b/nemo/collections/llm/api.py index 13f25eb21087..fdceff5d959e 100644 --- a/nemo/collections/llm/api.py +++ b/nemo/collections/llm/api.py @@ -25,7 +25,14 @@ from typing_extensions import Annotated import nemo.lightning as nl -from nemo.lightning import AutoResume, NeMoLogger, OptimizerModule, Trainer, io +from nemo.lightning import ( + AutoResume, + NeMoLogger, + OptimizerModule, + Trainer, + configure_no_restart_validation_training_loop, + io, +) from nemo.lightning.base import NEMO_MODELS_CACHE from nemo.lightning.pytorch.callbacks import PEFT, ModelTransform from nemo.utils import logging @@ -680,6 +687,7 @@ def _setup( tokenizer: Optional[TokenizerType], model_transform: Optional[Union[PEFT, ModelTransform, Callable]], ) -> Any: # Return type is Any because app_state's type is not specified + configure_no_restart_validation_training_loop(trainer) _log = log or NeMoLogger() if resume and isinstance(model_transform, PEFT) and _log.ckpt: logging.info("Disabling try_restore_best_ckpt restoration for adapters") diff --git a/nemo/collections/llm/gpt/data/hf_dataset.py b/nemo/collections/llm/gpt/data/hf_dataset.py index 7e70a970913e..5c6b71c74797 100644 --- a/nemo/collections/llm/gpt/data/hf_dataset.py +++ b/nemo/collections/llm/gpt/data/hf_dataset.py @@ -15,6 +15,7 @@ import pytorch_lightning as pl import torch from torch.utils.data import DataLoader +from nemo.lightning.pytorch.plugins import MegatronDataSampler class HfDatasetDataModule(pl.LightningDataModule): @@ -24,6 +25,7 @@ def __init__( num_workers=2, pin_memory=True, persistent_workers=True, + seq_length=1024, micro_batch_size=2, global_batch_size=2, pad_token_id=0, @@ -37,6 +39,7 @@ def __init__( self.num_workers = num_workers self.pin_memory = pin_memory self.persistent_workers = persistent_workers + self.seq_length = seq_length self.micro_batch_size = micro_batch_size self.global_batch_size = global_batch_size self.pad_token_id = pad_token_id @@ -58,6 +61,7 @@ def pad_within_micro(batch, pad_token_id): max_len = max(map(len, batch)) return [item + [pad_token_id] * (max_len - len(item)) for item in batch] + keys = list(filter(lambda x: x in batch[0], ['tokens', 'labels', 'position_ids', 'loss_mask'])) return { key: batchify( torch.LongTensor( @@ -67,16 +71,26 @@ def pad_within_micro(batch, pad_token_id): ) ) ) - for key in ['tokens', 'labels'] + for key in keys } + def setup(self, stage: str): + if not self.use_mcore_sampler: + return + self.data_sampler = MegatronDataSampler( + seq_len=self.seq_length, + micro_batch_size=self.micro_batch_size, + global_batch_size=self.global_batch_size, + dataloader_type=self.mcore_dataloader_type, + ) + def train_dataloader(self, collate_fn=None): from nemo.lightning.data import add_megatron_sampler if 
collate_fn is None: collate_fn = lambda x: HfDatasetDataModule.collate_fn(x, pad_token_id=self.pad_token_id) - dataloader = DataLoader( + return DataLoader( self.dataset, num_workers=self.num_workers, pin_memory=self.pin_memory, @@ -84,20 +98,3 @@ def train_dataloader(self, collate_fn=None): collate_fn=collate_fn, batch_size=self.micro_batch_size, ) - if not self.use_mcore_sampler: - return dataloader - - rank = 0 - world_size = 1 - if torch.distributed.is_initialized(): - rank = torch.distributed.get_rank() - world_size = torch.distributed.get_world_size() - - return add_megatron_sampler( - dataloader, - self.micro_batch_size, - self.global_batch_size, - dataloader_type=self.mcore_dataloader_type, - rank=rank, - world_size=world_size, - ) diff --git a/nemo/collections/llm/gpt/model/llama.py b/nemo/collections/llm/gpt/model/llama.py index 7b235d59ee89..a9d18220bcaf 100644 --- a/nemo/collections/llm/gpt/model/llama.py +++ b/nemo/collections/llm/gpt/model/llama.py @@ -273,7 +273,12 @@ def make_vocab_size_divisible_by(vocab_size): base //= 2 return base - output = LlamaConfig( + if getattr(source, 'rope_scaling', None) is not None and source.rope_scaling.get('rope_type') == 'llama3': + # Apply Llama3.1 customize rope scaling + cls = Llama31Config + else: + cls = LlamaConfig + output = cls( num_layers=source.num_hidden_layers, hidden_size=source.hidden_size, ffn_hidden_size=source.intermediate_size, diff --git a/nemo/collections/llm/recipes/__init__.py b/nemo/collections/llm/recipes/__init__.py index 9f53ec88bdc8..8f772e3da5b7 100644 --- a/nemo/collections/llm/recipes/__init__.py +++ b/nemo/collections/llm/recipes/__init__.py @@ -30,6 +30,8 @@ llama3_70b, llama3_70b_16k, llama3_70b_64k, + llama31_8b, + llama31_70b, llama31_405b, mamba2_1_3b, mamba2_2_7b, @@ -82,6 +84,8 @@ "llama3_70b", "llama3_70b_16k", "llama3_70b_64k", + "llama31_8b", + "llama31_70b", "llama31_405b", "mamba2_130m", "mamba2_370m", diff --git a/nemo/collections/llm/recipes/finetune_default.py b/nemo/collections/llm/recipes/finetune_default.py index 69266737edc9..a060046a8bdf 100644 --- a/nemo/collections/llm/recipes/finetune_default.py +++ b/nemo/collections/llm/recipes/finetune_default.py @@ -16,6 +16,7 @@ import nemo_run as run import pytorch_lightning as pl +import torch import nemo.lightning as nl from nemo.collections import llm @@ -82,7 +83,7 @@ def default_finetune_recipe( def default_finetune_trainer( tensor_parallelism=1, pipeline_parallelism=1, - pipeline_parallelism_type=None, + pipeline_parallelism_type=torch.bfloat16, virtual_pipeline_parallelism=None, context_parallelism=1, sequence_parallelism=False, @@ -93,6 +94,19 @@ def default_finetune_trainer( limit_val_batches=None, val_check_interval=30, ): + """ + Create a default fine-tuning trainer for any model. + + This function sets up a template for strategy and trainer. + + Args: + See docstrings of MegatronStrategy and Trainer. + + Returns: + run.Config: Config for a finetuning trainer. + + See usages of this in recipes for further details. + """ strategy = run.Config( nl.MegatronStrategy, tensor_model_parallel_size=tensor_parallelism, @@ -125,7 +139,8 @@ def default_finetune_trainer( def nemo_resume(model_id: str) -> run.Config[nl.AutoResume]: """ - Configure automatic resumption from a NeMo checkpoint converted from Huggingface for https://huggingface.co/{model_id}. + Configure automatic resumption from a NeMo checkpoint converted from Huggingface for + https://huggingface.co/{model_id}. 
This NeMo checkpoint should be converted from Huggingface beforehand, using nemo.collections.llm.import_ckpt. When converting the checkpoint, the NeMo checkpoint will be saved in NEMO_HOME (set to ~/.cache/nemo by default). diff --git a/nemo/collections/llm/recipes/gemma_2b.py b/nemo/collections/llm/recipes/gemma_2b.py index 3e54deb0bc1c..8b2111e9f7c4 100644 --- a/nemo/collections/llm/recipes/gemma_2b.py +++ b/nemo/collections/llm/recipes/gemma_2b.py @@ -282,7 +282,7 @@ def finetune_recipe( recipe.data.dataset_kwargs = {'add_bos': True} if peft_scheme is None or peft_scheme.lower() == 'none': - recipe.trainer.strategy.tensor_model_parallel_size = 2 + recipe.trainer.strategy.context_parallel_size = 2 recipe.optim.config.lr = 5e-6 elif peft_scheme.lower() == 'lora': recipe.peft = run.Config(LoRA) diff --git a/nemo/collections/llm/recipes/llama31_405b.py b/nemo/collections/llm/recipes/llama31_405b.py index e753c48387c0..31c83713b6e7 100644 --- a/nemo/collections/llm/recipes/llama31_405b.py +++ b/nemo/collections/llm/recipes/llama31_405b.py @@ -24,6 +24,7 @@ from nemo import lightning as nl from nemo.collections.llm.api import finetune, pretrain from nemo.collections.llm.gpt.data.mock import MockDataModule +from nemo.collections.llm.gpt.data.packed_sequence import PackedSequenceSpecs from nemo.collections.llm.gpt.model.llama import Llama31Config405B, LlamaModel from nemo.collections.llm.peft.lora import LoRA from nemo.collections.llm.recipes.finetune_default import default_finetune_recipe @@ -33,6 +34,7 @@ from nemo.collections.llm.recipes.tp_overlap_configs.userbuffers import ( userbuffers_bf16_h100_h16384_tp8_cp2_mbs1_seqlen8192, ) +from nemo.lightning.pytorch.callbacks import GarbageCollectionCallback from nemo.lightning.pytorch.callbacks.megatron_comm_overlap import MegatronCommOverlapCallback from nemo.utils.exp_manager import TimingCallback @@ -248,6 +250,9 @@ def finetune_recipe( num_nodes: int = 3, num_gpus_per_node: int = 8, peft_scheme: Optional[str] = 'lora', + seq_length: Optional[int] = None, + packed_sequence: Optional[bool] = None, + performance_mode: bool = False, ) -> run.Partial: """ Create a fine-tuning recipe for Llama3.1 405B model. @@ -261,8 +266,11 @@ def finetune_recipe( name (str): Name of the fine-tuning run. num_nodes (int): Number of compute nodes to use. num_gpus_per_node (int): Number of GPUs per node. - peft_scheme (Optional[str]): Name of the peft scheme to use for fine-tuning. Allowed values: 'lora', 'none'/None. - + peft_scheme (Optional[str]): Name of the peft scheme to use for finetuning. Allowed values: 'lora'/'none'/None. + seq_length (int): Maximum number of tokens per microbatch. + packed_sequence (Optional[bool]): If true, fine-tuning sequences will be packed into batches up to the given + maximum seq_length for better efficiency. By default, this value equals performance_mode. + performance_mode (bool): If true, enables optimizations for maximum performance. Returns: run.Partial: Partial configuration for fine-tuning. @@ -279,22 +287,116 @@ def finetune_recipe( This recipe uses the SQuAD dataset for fine-tuning. Be aware that fine-tuning a 405B model requires substantial computational resources. 
""" + if packed_sequence is None: + packed_sequence = performance_mode + + if seq_length is None: + seq_length = 2048 + + if num_nodes is None: + if peft_scheme is None or peft_scheme.lower() == 'none': + num_nodes = 12 + elif peft_scheme.lower() == 'lora': + num_nodes = 3 + recipe = default_finetune_recipe( - model(), "meta-llama/Meta-Llama-3.1-405B", dir, name, num_nodes, num_gpus_per_node + model(), "meta-llama/Llama-3.1-405B", dir, name, num_nodes, num_gpus_per_node, packed_sequence ) - if peft_scheme is None or peft_scheme.lower() == 'none': - assert num_nodes >= 4 recipe.trainer.strategy.tensor_model_parallel_size = 8 - recipe.trainer.strategy.pipeline_model_parallel_size = 4 + recipe.trainer.strategy.pipeline_model_parallel_size = 14 + recipe.data.global_batch_size = 6 recipe.optim.config.lr = 5e-6 elif peft_scheme.lower() == 'lora': recipe.peft = run.Config(LoRA) + recipe.peft.dim = 16 + recipe.peft.alpha = 32 + recipe.peft.target_modules = ['linear_qkv'] + recipe.optim.config.use_distributed_optimizer = False + + # some settings currently do not function correctly with LoRA + recipe.model.config.cross_entropy_loss_fusion = False recipe.trainer.strategy.tensor_model_parallel_size = 4 recipe.trainer.strategy.pipeline_model_parallel_size = 6 - recipe.trainer.strategy.virtual_pipeline_parallelism = 7 - recipe.data.global_batch_size = 128 + recipe.trainer.strategy.virtual_pipeline_model_parallel_size = 7 + recipe.data.global_batch_size = 6 recipe.optim.config.lr = 1e-4 else: raise ValueError(f"Unrecognized peft scheme: {peft_scheme}") + + # Sequence length settings in the model and dataset must agree + recipe.model.config.seq_length = seq_length + recipe.data.seq_length = seq_length + if packed_sequence: + recipe.data.dataset_kwargs = {'pad_to_max_length': True} + recipe.data.packed_sequence_specs = run.Config(PackedSequenceSpecs, packed_sequence_size=seq_length) + + if performance_mode: + recipe = finetune_performance_optimizations(recipe, peft_scheme) + + return recipe + + +def finetune_performance_optimizations( + recipe: run.Partial, + peft_scheme: str, +) -> run.Partial: + """ + Modify the given recipe to optimize settings for performance. + + This method enables performance optimizations that may not be suitable for all use cases. + Intended to build upon the standard fine-tuning recipe. + + Args: + recipe (run.Partial): Base fine-tuning recipe to which performance optimizations will be added + peft_scheme (str): Name of the peft scheme to use for fine-tuning. Allowed values: 'lora', 'none'/None. + + Returns: + run.Partial: Partial configuration for performance-optimized fine-tuning. + + Note: + Use this method with caution and only when you need maximum performance. + It may not be suitable for all hardware configurations or use cases. + """ + + if not hasattr(recipe.trainer, "callbacks"): + recipe.trainer.callbacks = [] + + if peft_scheme is None or peft_scheme.lower() == 'none': + # Note: limited support. 
This is not necessarily the most optimized setting + recipe.trainer.strategy.tensor_model_parallel_size = 8 + recipe.trainer.strategy.pipeline_model_parallel_size = 14 + recipe.trainer.plugins.grad_reduce_in_fp32 = False + recipe.trainer.strategy.ddp = run.Config( + DistributedDataParallelConfig, + check_for_nan_in_grad=True, + grad_reduce_in_fp32=False, + overlap_grad_reduce=True, + overlap_param_gather=True, + average_in_collective=True, + ) + recipe.trainer.callbacks.append( + run.Config( + MegatronCommOverlapCallback, + tp_comm_overlap=True, + defer_embedding_wgrad_compute=True, + wgrad_deferral_limit=22, + ) + ) + else: + recipe.trainer.strategy.tensor_model_parallel_size = 4 + recipe.trainer.strategy.pipeline_model_parallel_size = 6 + recipe.trainer.strategy.virtual_pipeline_model_parallel_size = 7 + + recipe.trainer.strategy.sequence_parallel = True + + recipe.trainer.callbacks.append(run.Config(TimingCallback)) + recipe.trainer.callbacks.append( + run.Config( + GarbageCollectionCallback, + 100, + 100, + ) + ) + return recipe diff --git a/nemo/collections/llm/recipes/llama31_70b.py b/nemo/collections/llm/recipes/llama31_70b.py new file mode 100644 index 000000000000..91e4e10c83e6 --- /dev/null +++ b/nemo/collections/llm/recipes/llama31_70b.py @@ -0,0 +1,403 @@ +# Copyright (c) 2024, NVIDIA CORPORATION. All rights reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + + +from typing import Callable, Optional + +import nemo_run as run +import pytorch_lightning as pl +import torch +from megatron.core.distributed import DistributedDataParallelConfig +from pytorch_lightning.callbacks.callback import Callback + +from nemo import lightning as nl +from nemo.collections.llm.api import finetune, pretrain +from nemo.collections.llm.gpt.data.mock import MockDataModule +from nemo.collections.llm.gpt.data.packed_sequence import PackedSequenceSpecs +from nemo.collections.llm.gpt.model.llama import Llama31Config70B, LlamaModel +from nemo.collections.llm.peft.lora import LoRA +from nemo.collections.llm.recipes.finetune_default import default_finetune_recipe +from nemo.collections.llm.recipes.log.default import default_log, default_resume, tensorboard_logger +from nemo.collections.llm.recipes.optim.adam import distributed_fused_adam_with_cosine_annealing +from nemo.collections.llm.recipes.precision.mixed_precision import bf16_mixed +from nemo.collections.llm.recipes.tp_overlap_configs.userbuffers import ( + userbuffers_bf16_h100_h16384_tp8_cp2_mbs1_seqlen8192, +) +from nemo.lightning.pytorch.callbacks import GarbageCollectionCallback +from nemo.lightning.pytorch.callbacks.megatron_comm_overlap import MegatronCommOverlapCallback +from nemo.utils.exp_manager import TimingCallback + +NAME = "llama31_70b" + + +@run.cli.factory(name=NAME) +def model() -> run.Config[pl.LightningModule]: + """ + Factory function to create a Llama3.1 70B model configuration. + + Returns: + run.Config[pl.LightningModule]: Configuration for the Llama3.1 70B model. 
+ + Examples: + CLI usage: + $ nemo llm pretrain model=llama31_70b ... + + Python API usage: + >>> model_config = model() + >>> print(model_config) + """ + conf = run.Config(Llama31Config70B) + conf.seq_length = 8192 + return run.Config(LlamaModel, config=conf) + + +def trainer( + tensor_parallelism: int = 4, + pipeline_parallelism: int = 4, + pipeline_parallelism_type: Optional[torch.dtype] = torch.bfloat16, + virtual_pipeline_parallelism: Optional[int] = 5, + context_parallelism: int = 2, + sequence_parallelism: bool = True, + num_nodes: int = 4, + num_gpus_per_node: int = 8, + max_steps: int = 1168251, + callbacks: Optional[list[run.Config[Callback]]] = None, +) -> run.Config[nl.Trainer]: + """ + Configure the NeMo Lightning Trainer for Llama3.1 70B model. + + This function sets up the distributed training strategy optimized for the large 70B model. + + Args: + tensor_parallelism (int): Degree of tensor model parallelism. + pipeline_parallelism (int): Degree of pipeline model parallelism. + pipeline_parallelism_type (Optional[torch.dtype]): Data type for pipeline parallelism. + virtual_pipeline_parallelism (Optional[int]): Size of virtual pipeline parallelism. + context_parallelism (int): Degree of context parallelism. + sequence_parallelism (bool): Whether to use sequence parallelism. + num_nodes (int): Number of compute nodes to use. + num_gpus_per_node (int): Number of GPUs per node. + max_steps (int): Maximum number of training steps. + callbacks (Optional[list[run.Config[Callback]]]): List of callback configurations. + + Returns: + run.Config[nl.Trainer]: Configuration for the NeMo Lightning Trainer. + + Examples: + CLI usage: + $ nemo llm pretrain trainer=llama31_70b ... + + Python API usage: + >>> trainer_config = trainer(num_nodes=4, num_gpus_per_node=8) + >>> print(trainer_config) + + Note: + This configuration uses extensive parallelism to handle the large model size efficiently. + """ + strategy = run.Config( + nl.MegatronStrategy, + tensor_model_parallel_size=tensor_parallelism, + pipeline_model_parallel_size=pipeline_parallelism, + pipeline_dtype=pipeline_parallelism_type, + virtual_pipeline_model_parallel_size=virtual_pipeline_parallelism, + context_parallel_size=context_parallelism, + sequence_parallel=sequence_parallelism, + gradient_as_bucket_view=True, + ckpt_async_save=True, + ckpt_parallel_load=True, + ddp=run.Config( + DistributedDataParallelConfig, + check_for_nan_in_grad=True, + grad_reduce_in_fp32=True, + overlap_grad_reduce=True, + overlap_param_gather=True, + average_in_collective=True, + ), + ) + + trainer = run.Config( + nl.Trainer, + accelerator="gpu", + accumulate_grad_batches=1, + callbacks=callbacks, + devices=num_gpus_per_node, + limit_test_batches=50, + limit_val_batches=32, + log_every_n_steps=10, + max_steps=max_steps, + num_nodes=num_nodes, + plugins=bf16_mixed(), + strategy=strategy, + use_distributed_sampler=False, + val_check_interval=2000, + ) + + return trainer + + +@run.cli.factory(target=pretrain, name=NAME) +def pretrain_recipe( + dir: Optional[str] = None, + name: str = "default", + num_nodes: int = 1, + num_gpus_per_node: int = 8, + performance_mode: bool = False, + fn: Callable = pretrain, +) -> run.Partial: + """ + Create a pre-training recipe for Llama3.1 70B model. + + This function sets up a complete configuration for pre-training, including + model, trainer, data, logging, optimization, and resumption settings. + + Args: + dir (Optional[str]): Directory for saving logs and checkpoints. + name (str): Name of the pre-training run. 
+ num_nodes (int): Number of compute nodes to use. + num_gpus_per_node (int): Number of GPUs per node. + performance_mode (bool): If true, enables optimizations for maximum performance. + fn (Callable): The pre-training function to use. + + Returns: + run.Partial: Partial configuration for pre-training. + + Examples: + CLI usage: + $ nemo llm pretrain --factory llama31_70b + $ nemo llm pretrain --factory "llama31_70b(num_nodes=4, name='my_70b_pretrain')" + + Python API usage: + >>> recipe = pretrain_recipe(name="llama31_70b_pretrain", num_nodes=4) + >>> print(recipe) + + Note: + This recipe is optimized for the large 70B model and requires significant computational resources. + """ + recipe = run.Partial( + fn, + model=model(), + trainer=trainer( + num_nodes=num_nodes, + num_gpus_per_node=num_gpus_per_node, + callbacks=[run.Config(TimingCallback)], + ), + data=run.Config(MockDataModule, seq_length=8192, global_batch_size=512, micro_batch_size=1), + log=default_log(dir=dir, name=name, tensorboard_logger=tensorboard_logger(name=name)), + optim=distributed_fused_adam_with_cosine_annealing(max_lr=3e-4), + resume=default_resume(), + ) + + if performance_mode: + recipe = pretrain_performance_optimizations(recipe) + + return recipe + + +def pretrain_performance_optimizations(recipe: run.Partial) -> run.Partial: + """ + Create a performance-optimized pre-training recipe for Llama3.1 70B model. + + This method enables performance optimizations that may not be suitable for all use cases. + It builds upon the standard pre-training recipe and adds additional performance enhancements. + + Args: + recipe (run.Partial): Base pre-train recipe to which performance optimizations will be added + + Returns: + run.Partial: Partial configuration for performance-optimized pre-training. + + Note: + Use this method with caution and only when you need maximum performance. + It may not be suitable for all hardware configurations or use cases. + """ + + # 'overlap_param_gather_with_optimizer_step' and 'align_param_gather' params are set automatically + # by MegatronCommOverlapCallback. They are added here for user's knowledge. + # overlap_param_gather_with_optimizer_step- Overlap param all-gather of first bucket with optimizer step. + # align_param_gather- If true, all PP stages launch param all-gathers simultaneously, else + # each PP stage launches independently as needed. + + recipe.trainer.callbacks.append( + run.Config( + MegatronCommOverlapCallback, + tp_comm_overlap=True, + tp_comm_overlap_cfg=userbuffers_bf16_h100_h16384_tp8_cp2_mbs1_seqlen8192, + defer_embedding_wgrad_compute=True, + wgrad_deferral_limit=50, + overlap_param_gather_with_optimizer_step=False, # Currently disabled due to an issue with checkpointing + align_param_gather=True, + ) + ) + + return recipe + + +@run.cli.factory(target=finetune, name=NAME) +def finetune_recipe( + dir: Optional[str] = None, + name: str = "default", + num_nodes: int = None, + num_gpus_per_node: int = 8, + peft_scheme: Optional[str] = 'lora', + seq_length: Optional[int] = None, + packed_sequence: Optional[bool] = None, + performance_mode: bool = False, +) -> run.Partial: + """ + Create a fine-tuning recipe for Llama3.1 70B model. + + This function sets up a complete configuration for fine-tuning, including + model, trainer, data, logging, optimization, and resumption settings. + The recipe uses LoRA (Low-Rank Adaptation) for efficient fine-tuning, unless peft_scheme is set to None. + + Args: + dir (Optional[str]): Directory for saving logs and checkpoints. 
+ name (str): Name of the fine-tuning run. + num_nodes (int): Number of compute nodes to use. + num_gpus_per_node (int): Number of GPUs per node. + peft_scheme (Optional[str]): Name of the peft scheme to use for finetuning. Allowed values: 'lora'/'none'/None. + seq_length (int): Maximum number of tokens per microbatch. + packed_sequence (Optional[bool]): If true, fine-tuning sequences will be packed into batches up to the given + maximum seq_length for better efficiency. By default, this value equals performance_mode. + performance_mode (bool): If true, enables optimizations for maximum performance. + + Returns: + run.Partial: Partial configuration for fine-tuning. + + Examples: + CLI usage: + $ nemo llm finetune --factory llama31_70b + $ nemo llm finetune --factory "llama31_70b(num_nodes=4, name='my_70b_finetune')" + + Python API usage: + >>> recipe = finetune_recipe(name="llama31_70b_finetune", num_nodes=4) + >>> print(recipe) + + Note: + This recipe uses the SQuAD dataset for fine-tuning. Be aware that fine-tuning a 70B model + requires substantial computational resources. + """ + # Default to unpacked data in normal mode and packed data in performance mode + # once packing recipe is well tested, change this default to true + if packed_sequence is None: + packed_sequence = performance_mode + + # For unpacked sequence, most samples in SQuAD dataset are shorter than 2K + if seq_length is None: + seq_length = 4096 if packed_sequence else 2048 + + if num_nodes is None: + if peft_scheme is None or peft_scheme.lower() == 'none': + num_nodes = 4 + elif peft_scheme.lower() == 'lora': + num_nodes = 1 + + recipe = default_finetune_recipe( + model(), "meta-llama/Llama-3.1-70B", dir, name, num_nodes, num_gpus_per_node, packed_sequence + ) + if peft_scheme is None or peft_scheme.lower() == 'none': + recipe.trainer.strategy.tensor_model_parallel_size = 8 + recipe.trainer.strategy.pipeline_model_parallel_size = 4 + recipe.optim.config.lr = 5e-6 + elif peft_scheme.lower() == 'lora': + recipe.peft = run.Config(LoRA) + recipe.peft.dim = 16 + recipe.peft.alpha = 32 + recipe.peft.target_modules = ['linear_qkv'] + recipe.optim.config.use_distributed_optimizer = False + + # some settings currently do not function correctly with LoRA + recipe.model.config.cross_entropy_loss_fusion = False + + recipe.trainer.strategy.tensor_model_parallel_size = 8 + recipe.optim.config.lr = 1e-4 + else: + raise ValueError(f"Unrecognized peft scheme: {peft_scheme}") + + # Sequence length settings in the model and dataset must agree + recipe.model.config.seq_length = seq_length + recipe.data.seq_length = seq_length + if packed_sequence: + recipe.data.dataset_kwargs = {'pad_to_max_length': True} + recipe.data.packed_sequence_specs = run.Config(PackedSequenceSpecs, packed_sequence_size=seq_length) + + if performance_mode: + recipe = finetune_performance_optimizations(recipe, peft_scheme) + + return recipe + + +def finetune_performance_optimizations( + recipe: run.Partial, + peft_scheme: str, +) -> run.Partial: + """ + Modify the given recipe to optimize settings for performance. + + This method enables performance optimizations that may not be suitable for all use cases. + Intended to build upon the standard fine-tuning recipe. + + Args: + recipe (run.Partial): Base fine-tuning recipe to which performance optimizations will be added + peft_scheme (str): Name of the peft scheme to use for fine-tuning. Allowed values: 'lora', 'none'/None. + + Returns: + run.Partial: Partial configuration for performance-optimized fine-tuning. 
+ + Note: + Use this method with caution and only when you need maximum performance. + It may not be suitable for all hardware configurations or use cases. + """ + + if not hasattr(recipe.trainer, "callbacks"): + recipe.trainer.callbacks = [] + + if peft_scheme is None or peft_scheme.lower() == 'none': + recipe.trainer.strategy.tensor_model_parallel_size = 4 + recipe.trainer.strategy.pipeline_model_parallel_size = 4 + recipe.trainer.strategy.virtual_pipeline_model_parallel_size = 5 + recipe.trainer.plugins.grad_reduce_in_fp32 = False + recipe.trainer.strategy.ddp = run.Config( + DistributedDataParallelConfig, + check_for_nan_in_grad=True, + grad_reduce_in_fp32=False, + overlap_grad_reduce=True, + overlap_param_gather=True, + average_in_collective=True, + ) + recipe.trainer.callbacks.append( + run.Config( + MegatronCommOverlapCallback, + tp_comm_overlap=True, + defer_embedding_wgrad_compute=True, + wgrad_deferral_limit=22, + ) + ) + else: + recipe.trainer.strategy.tensor_model_parallel_size = 2 + recipe.trainer.strategy.pipeline_model_parallel_size = 4 + recipe.trainer.strategy.virtual_pipeline_model_parallel_size = 5 + + recipe.trainer.strategy.sequence_parallel = True + + recipe.trainer.callbacks.append(run.Config(TimingCallback)) + recipe.trainer.callbacks.append( + run.Config( + GarbageCollectionCallback, + 100, + 100, + ) + ) + + return recipe diff --git a/nemo/collections/llm/recipes/llama31_8b.py b/nemo/collections/llm/recipes/llama31_8b.py new file mode 100644 index 000000000000..a4f0082e8535 --- /dev/null +++ b/nemo/collections/llm/recipes/llama31_8b.py @@ -0,0 +1,385 @@ +# Copyright (c) 2024, NVIDIA CORPORATION. All rights reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. 
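# A minimal usage sketch for the llama31_70b factories defined above; the run name is
# hypothetical, and attribute access on the returned run.Partial mirrors how these recipes
# are mutated elsewhere in this file.
from nemo.collections.llm.recipes import llama31_70b

# performance_mode=True makes packed_sequence default to True, which in turn makes
# seq_length default to 4096; both the data module and the model config receive that value.
recipe = llama31_70b.finetune_recipe(name="llama31_70b_lora", peft_scheme="lora", performance_mode=True)
assert recipe.data.seq_length == recipe.model.config.seq_length == 4096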
+ + +from typing import Callable, Optional + +import nemo_run as run +import pytorch_lightning as pl +import torch +from megatron.core.distributed import DistributedDataParallelConfig +from pytorch_lightning.callbacks.callback import Callback + +from nemo import lightning as nl +from nemo.collections.llm.api import finetune, pretrain +from nemo.collections.llm.gpt.data.mock import MockDataModule +from nemo.collections.llm.gpt.data.packed_sequence import PackedSequenceSpecs +from nemo.collections.llm.gpt.model.llama import Llama31Config8B, LlamaModel +from nemo.collections.llm.peft.lora import LoRA +from nemo.collections.llm.recipes.finetune_default import default_finetune_recipe +from nemo.collections.llm.recipes.log.default import default_log, default_resume, tensorboard_logger +from nemo.collections.llm.recipes.optim.adam import distributed_fused_adam_with_cosine_annealing +from nemo.collections.llm.recipes.precision.mixed_precision import bf16_mixed +from nemo.collections.llm.recipes.tp_overlap_configs.userbuffers import ( + userbuffers_bf16_h100_h16384_tp8_cp2_mbs1_seqlen8192, +) +from nemo.lightning.pytorch.callbacks import GarbageCollectionCallback +from nemo.lightning.pytorch.callbacks.megatron_comm_overlap import MegatronCommOverlapCallback +from nemo.utils.exp_manager import TimingCallback + +NAME = "llama31_8b" + + +@run.cli.factory(name=NAME) +def model() -> run.Config[pl.LightningModule]: + """ + Factory function to create a Llama3.1 8B model configuration. + + Returns: + run.Config[pl.LightningModule]: Configuration for the Llama3.1 8B model. + + Examples: + CLI usage: + $ nemo llm pretrain model=llama31_8b ... + + Python API usage: + >>> model_config = model() + >>> print(model_config) + """ + conf = run.Config(Llama31Config8B) + conf.seq_length = 8192 + return run.Config(LlamaModel, config=conf) + + +def trainer( + tensor_parallelism: int = 1, + pipeline_parallelism: int = 1, + pipeline_parallelism_type: Optional[torch.dtype] = None, + virtual_pipeline_parallelism: Optional[int] = None, + context_parallelism: int = 2, + sequence_parallelism: bool = False, + num_nodes: int = 1, + num_gpus_per_node: int = 8, + max_steps: int = 1168251, + callbacks: Optional[list[run.Config[Callback]]] = None, +) -> run.Config[nl.Trainer]: + """ + Configure the NeMo Lightning Trainer for Llama3.1 8B model. + + This function sets up the distributed training strategy optimized for the large 8B model. + + Args: + tensor_parallelism (int): Degree of tensor model parallelism. + pipeline_parallelism (int): Degree of pipeline model parallelism. + pipeline_parallelism_type (Optional[torch.dtype]): Data type for pipeline parallelism. + virtual_pipeline_parallelism (Optional[int]): Size of virtual pipeline parallelism. + context_parallelism (int): Degree of context parallelism. + sequence_parallelism (bool): Whether to use sequence parallelism. + num_nodes (int): Number of compute nodes to use. + num_gpus_per_node (int): Number of GPUs per node. + max_steps (int): Maximum number of training steps. + callbacks (Optional[list[run.Config[Callback]]]): List of callback configurations. + + Returns: + run.Config[nl.Trainer]: Configuration for the NeMo Lightning Trainer. + + Examples: + CLI usage: + $ nemo llm pretrain trainer=llama31_8b ... + + Python API usage: + >>> trainer_config = trainer(num_nodes=2, num_gpus_per_node=8) + >>> print(trainer_config) + + Note: + This configuration uses extensive parallelism to handle the large model size efficiently. 
+ """ + strategy = run.Config( + nl.MegatronStrategy, + tensor_model_parallel_size=tensor_parallelism, + pipeline_model_parallel_size=pipeline_parallelism, + pipeline_dtype=pipeline_parallelism_type, + virtual_pipeline_model_parallel_size=virtual_pipeline_parallelism, + context_parallel_size=context_parallelism, + sequence_parallel=sequence_parallelism, + gradient_as_bucket_view=True, + ckpt_async_save=True, + ckpt_parallel_load=True, + ddp=run.Config( + DistributedDataParallelConfig, + check_for_nan_in_grad=True, + grad_reduce_in_fp32=True, + overlap_grad_reduce=True, + overlap_param_gather=True, + average_in_collective=True, + ), + ) + + trainer = run.Config( + nl.Trainer, + accelerator="gpu", + accumulate_grad_batches=1, + callbacks=callbacks, + devices=num_gpus_per_node, + limit_test_batches=50, + limit_val_batches=32, + log_every_n_steps=10, + max_steps=max_steps, + num_nodes=num_nodes, + plugins=bf16_mixed(), + strategy=strategy, + use_distributed_sampler=False, + val_check_interval=2000, + ) + + return trainer + + +@run.cli.factory(target=pretrain, name=NAME) +def pretrain_recipe( + dir: Optional[str] = None, + name: str = "default", + num_nodes: int = 1, + num_gpus_per_node: int = 8, + performance_mode: bool = False, + fn: Callable = pretrain, +) -> run.Partial: + """ + Create a pre-training recipe for Llama3.1 8B model. + + This function sets up a complete configuration for pre-training, including + model, trainer, data, logging, optimization, and resumption settings. + + Args: + dir (Optional[str]): Directory for saving logs and checkpoints. + name (str): Name of the pre-training run. + num_nodes (int): Number of compute nodes to use. + num_gpus_per_node (int): Number of GPUs per node. + performance_mode (bool): If true, enables optimizations for maximum performance. + fn (Callable): The pre-training function to use. + + Returns: + run.Partial: Partial configuration for pre-training. + + Examples: + CLI usage: + $ nemo llm pretrain --factory llama31_8b + $ nemo llm pretrain --factory "llama31_8b(num_nodes=4, name='my_8b_pretrain')" + + Python API usage: + >>> recipe = pretrain_recipe(name="llama31_8b_pretrain", num_nodes=4) + >>> print(recipe) + + Note: + This recipe is optimized for the large 8B model and requires significant computational resources. + """ + recipe = run.Partial( + fn, + model=model(), + trainer=trainer( + num_nodes=num_nodes, + num_gpus_per_node=num_gpus_per_node, + callbacks=[run.Config(TimingCallback)], + ), + data=run.Config(MockDataModule, seq_length=8192, global_batch_size=512, micro_batch_size=1), + log=default_log(dir=dir, name=name, tensorboard_logger=tensorboard_logger(name=name)), + optim=distributed_fused_adam_with_cosine_annealing(max_lr=3e-4), + resume=default_resume(), + ) + + if performance_mode: + recipe = pretrain_performance_optimizations(recipe) + + return recipe + + +def pretrain_performance_optimizations(recipe: run.Partial) -> run.Partial: + """ + Create a performance-optimized pre-training recipe for Llama3.1 8B model. + + This method enables performance optimizations that may not be suitable for all use cases. + It builds upon the standard pre-training recipe and adds additional performance enhancements. + + Args: + recipe (run.Partial): Base pre-train recipe to which performance optimizations will be added + + Returns: + run.Partial: Partial configuration for performance-optimized pre-training. + + Note: + Use this method with caution and only when you need maximum performance. 
+ It may not be suitable for all hardware configurations or use cases. + """ + + # 'overlap_param_gather_with_optimizer_step' and 'align_param_gather' params are set automatically + # by MegatronCommOverlapCallback. They are added here for user's knowledge. + # overlap_param_gather_with_optimizer_step- Overlap param all-gather of first bucket with optimizer step. + # align_param_gather- If true, all PP stages launch param all-gathers simultaneously, else + # each PP stage launches independently as needed. + + recipe.trainer.callbacks.append( + run.Config( + MegatronCommOverlapCallback, + tp_comm_overlap=True, + tp_comm_overlap_cfg=userbuffers_bf16_h100_h16384_tp8_cp2_mbs1_seqlen8192, + defer_embedding_wgrad_compute=True, + wgrad_deferral_limit=50, + overlap_param_gather_with_optimizer_step=False, # Currently disabled due to an issue with checkpointing + align_param_gather=True, + ) + ) + + return recipe + + +@run.cli.factory(target=finetune, name=NAME) +def finetune_recipe( + dir: Optional[str] = None, + name: str = "default", + num_nodes: int = 1, + num_gpus_per_node: int = 8, + peft_scheme: Optional[str] = 'lora', + seq_length: Optional[int] = None, + packed_sequence: Optional[bool] = None, + performance_mode: bool = False, +) -> run.Partial: + """ + Create a fine-tuning recipe for Llama3.1 8B model. + + This function sets up a complete configuration for fine-tuning, including + model, trainer, data, logging, optimization, and resumption settings. + The recipe uses LoRA (Low-Rank Adaptation) for efficient fine-tuning, unless peft_scheme is set to None. + + Args: + dir (Optional[str]): Directory for saving logs and checkpoints. + name (str): Name of the fine-tuning run. + num_nodes (int): Number of compute nodes to use. + num_gpus_per_node (int): Number of GPUs per node. + peft_scheme (Optional[str]): Name of the peft scheme to use for finetuning. Allowed values: 'lora'/'none'/None. + seq_length (int): Maximum number of tokens per microbatch. + packed_sequence (Optional[bool]): If true, fine-tuning sequences will be packed into batches up to the given + maximum seq_length for better efficiency. By default, this value equals performance_mode. + performance_mode (bool): If true, enables optimizations for maximum performance. + + Returns: + run.Partial: Partial configuration for fine-tuning. + + Examples: + CLI usage: + $ nemo llm finetune --factory llama31_8b + + Python API usage: + >>> recipe = finetune_recipe(name="llama31_8b_finetune", num_nodes=2) + >>> print(recipe) + + Note: + This recipe uses the SQuAD dataset for fine-tuning. For more information + on fine-tuning LLMs with NeMo, see the fine-tuning guide in the + `examples/llm/finetune/` directory. 
+ """ + # Default to unpacked data in normal mode and packed data in performance mode + # once packing recipe is well tested, change this default to true + if packed_sequence is None: + packed_sequence = performance_mode + + # For unpacked sequence, most samples in SQuAD dataset are shorter than 2K + if seq_length is None: + seq_length = 4096 if packed_sequence else 2048 + + recipe = default_finetune_recipe( + model(), "meta-llama/Meta-Llama-3.1-8B", dir, name, num_nodes, num_gpus_per_node, packed_sequence + ) + if peft_scheme is None or peft_scheme.lower() == 'none': + recipe.trainer.strategy.tensor_model_parallel_size = 2 + recipe.optim.config.lr = 5e-6 + elif peft_scheme.lower() == 'lora': + recipe.peft = run.Config(LoRA) + recipe.peft.dim = 8 + recipe.peft.alpha = 16 + recipe.peft.target_modules = ['linear_qkv'] + recipe.optim.config.use_distributed_optimizer = False + + # some settings currently do not function correctly with LoRA + recipe.model.config.cross_entropy_loss_fusion = False + + recipe.optim.config.lr = 1e-4 + else: + raise ValueError(f"Unrecognized peft scheme: {peft_scheme}") + + # Sequence length settings in the model and dataset must agree + recipe.model.config.seq_length = seq_length + recipe.data.seq_length = seq_length + if packed_sequence: + recipe.data.dataset_kwargs = {'pad_to_max_length': True} + recipe.data.packed_sequence_specs = run.Config(PackedSequenceSpecs, packed_sequence_size=seq_length) + + if performance_mode: + recipe = finetune_performance_optimizations(recipe, peft_scheme) + + return recipe + + +def finetune_performance_optimizations( + recipe: run.Partial, + peft_scheme: str, +) -> run.Partial: + """ + Modify the given recipe to optimize settings for performance. + + This method enables performance optimizations that may not be suitable for all use cases. + Intended to build upon the standard fine-tuning recipe. + + Args: + recipe (run.Partial): Base fine-tuning recipe to which performance optimizations will be added + peft_scheme (str): Name of the peft scheme to use for fine-tuning. Allowed values: 'lora', 'none'/None. + + Returns: + run.Partial: Partial configuration for performance-optimized fine-tuning. + + Note: + Use this method with caution and only when you need maximum performance. + It may not be suitable for all hardware configurations or use cases. + """ + recipe.trainer.strategy.tensor_model_parallel_size = 1 + + if not hasattr(recipe.trainer, "callbacks"): + recipe.trainer.callbacks = [] + + if peft_scheme is None or peft_scheme.lower() == 'none': + recipe.trainer.plugins.grad_reduce_in_fp32 = False + recipe.trainer.strategy.ddp = run.Config( + DistributedDataParallelConfig, + check_for_nan_in_grad=True, + grad_reduce_in_fp32=False, + overlap_grad_reduce=True, + overlap_param_gather=True, + average_in_collective=True, + ) + recipe.trainer.callbacks.append( + run.Config( + MegatronCommOverlapCallback, + tp_comm_overlap=False, + ) + ) + + recipe.trainer.callbacks.append(run.Config(TimingCallback)) + recipe.trainer.callbacks.append( + run.Config( + GarbageCollectionCallback, + 100, + 100, + ) + ) + + return recipe diff --git a/nemo/collections/llm/recipes/llama3_70b.py b/nemo/collections/llm/recipes/llama3_70b.py index e2156993647d..d43302a0a0ee 100644 --- a/nemo/collections/llm/recipes/llama3_70b.py +++ b/nemo/collections/llm/recipes/llama3_70b.py @@ -263,9 +263,10 @@ def finetune_recipe( name (str): Name of the fine-tuning run. num_nodes (int): Number of compute nodes to use. num_gpus_per_node (int): Number of GPUs per node. 
- peft_scheme (Optional[str]): Name of the peft scheme to use for fine-tuning. Allowed values: 'lora', 'none'/None. + peft_scheme (Optional[str]): Name of the peft scheme to use for finetuning. Allowed values: 'lora'/'none'/None. seq_length (int): Maximum number of tokens per microbatch. - packed_sequence (Optional[bool]): If true, fine-tuning sequences will be packed into batches up to the given maximum seq_length for better efficiency. By default, this value equals performance_mode. + packed_sequence (Optional[bool]): If true, fine-tuning sequences will be packed into batches up to the given + maximum seq_length for better efficiency. By default, this value equals performance_mode. performance_mode (bool): If true, enables optimizations for maximum performance. Returns: @@ -325,7 +326,7 @@ def finetune_recipe( recipe.model.config.seq_length = seq_length recipe.data.seq_length = seq_length if packed_sequence: - recipe.data.pad_to_max_length = True + recipe.data.dataset_kwargs = {'pad_to_max_length': True} recipe.data.packed_sequence_specs = run.Config(PackedSequenceSpecs, packed_sequence_size=seq_length) if performance_mode: diff --git a/nemo/collections/llm/recipes/llama3_8b.py b/nemo/collections/llm/recipes/llama3_8b.py index 1030ad8799a1..4f6f6ce17443 100644 --- a/nemo/collections/llm/recipes/llama3_8b.py +++ b/nemo/collections/llm/recipes/llama3_8b.py @@ -25,7 +25,6 @@ from nemo.collections.llm.api import finetune, pretrain from nemo.collections.llm.gpt.data.mock import MockDataModule from nemo.collections.llm.gpt.data.packed_sequence import PackedSequenceSpecs -from nemo.collections.llm.gpt.data.squad import SquadDataModule from nemo.collections.llm.gpt.model.llama import Llama3Config8B, LlamaModel from nemo.collections.llm.peft.lora import LoRA from nemo.collections.llm.recipes.finetune_default import default_finetune_recipe @@ -251,9 +250,10 @@ def finetune_recipe( name (str): Name of the fine-tuning run. num_nodes (int): Number of compute nodes to use. num_gpus_per_node (int): Number of GPUs per node. - peft_scheme (Optional[str]): Name of the peft scheme to use for fine-tuning. Allowed values: 'lora', 'none'/None. + peft_scheme (Optional[str]): Name of the peft scheme to use for finetuning. Allowed values: 'lora'/'none'/None. seq_length (int): Maximum number of tokens per microbatch. - packed_sequence (Optional[bool]): If true, fine-tuning sequences will be packed into batches up to the given maximum seq_length for better efficiency. By default, this value equals performance_mode. + packed_sequence (Optional[bool]): If true, fine-tuning sequences will be packed into batches up to the given + maximum seq_length for better efficiency. By default, this value equals performance_mode. performance_mode (bool): If true, enables optimizations for maximum performance. 
Returns: @@ -305,7 +305,7 @@ def finetune_recipe( recipe.model.config.seq_length = seq_length recipe.data.seq_length = seq_length if packed_sequence: - recipe.data.pad_to_max_length = True + recipe.data.dataset_kwargs = {'pad_to_max_length': True} recipe.data.packed_sequence_specs = run.Config(PackedSequenceSpecs, packed_sequence_size=seq_length) if performance_mode: diff --git a/nemo/collections/llm/recipes/starcoder_15b.py b/nemo/collections/llm/recipes/starcoder_15b.py index d4e76abe897e..cb0ba14df868 100644 --- a/nemo/collections/llm/recipes/starcoder_15b.py +++ b/nemo/collections/llm/recipes/starcoder_15b.py @@ -300,7 +300,7 @@ def finetune_recipe( """ recipe = default_finetune_recipe(model(), "bigcode/starcoder", dir, name, num_nodes, num_gpus_per_node) if peft_scheme is None or peft_scheme.lower() == 'none': - recipe.trainer.strategy.tensor_model_parallel_size = 4 + recipe.trainer.strategy.pipeline_model_parallel_size = 8 recipe.optim.config.lr = 5e-6 elif peft_scheme.lower() == 'lora': recipe.peft = run.Config(LoRA) diff --git a/nemo/lightning/__init__.py b/nemo/lightning/__init__.py index 2cc720e148d4..91d3b3f936d0 100644 --- a/nemo/lightning/__init__.py +++ b/nemo/lightning/__init__.py @@ -33,7 +33,7 @@ from nemo.lightning.pytorch.plugins import data_sampler as _data_sampler from nemo.lightning.pytorch.strategies import FSDPStrategy, MegatronStrategy from nemo.lightning.pytorch.strategies.utils import RestoreConfig -from nemo.lightning.pytorch.trainer import Trainer +from nemo.lightning.pytorch.trainer import Trainer, configure_no_restart_validation_training_loop from nemo.lightning.resume import AutoResume @@ -66,6 +66,7 @@ def _is_slurm_interactive_mode(): "ModelCheckpoint", "OptimizerModule", "Trainer", + "configure_no_restart_validation_training_loop", "get_vocab_size", "teardown", ] diff --git a/nemo/lightning/io/mixin.py b/nemo/lightning/io/mixin.py index 3613444b6330..33b7afdf1e76 100644 --- a/nemo/lightning/io/mixin.py +++ b/nemo/lightning/io/mixin.py @@ -608,7 +608,9 @@ def _io_flatten_object(instance): def _io_unflatten_object(values, metadata): - assert hasattr(_thread_local, "output_dir") + if not hasattr(_thread_local, "output_dir"): + return fdl.Config.__unflatten__(values, metadata) + output_dir = _thread_local.output_dir if len(values) == 1: diff --git a/nemo/lightning/nemo_logger.py b/nemo/lightning/nemo_logger.py index 8b10f9aca50a..a901a3a8842a 100644 --- a/nemo/lightning/nemo_logger.py +++ b/nemo/lightning/nemo_logger.py @@ -220,7 +220,7 @@ def _setup_trainer_model_checkpoint(self, trainer, log_dir, ckpt=None): if callback.dirpath is None: callback.dirpath = Path(log_dir / "checkpoints") if callback.filename is None: - callback.filename = f"{self.name}--{{{callback.monitor}:.4f}}-{{epoch}}" + callback.filename = f"{self.name}--{{{callback.monitor}:.4f}}-{{epoch}}-{{consumed_samples}}" ModelCheckpoint.CHECKPOINT_NAME_LAST = callback.filename + "-last" def _handle_task_config(self, task_config, log_dir): diff --git a/nemo/lightning/pytorch/trainer.py b/nemo/lightning/pytorch/trainer.py index 0d71c49bf198..c97c59ef524d 100644 --- a/nemo/lightning/pytorch/trainer.py +++ b/nemo/lightning/pytorch/trainer.py @@ -12,10 +12,13 @@ # See the License for the specific language governing permissions and # limitations under the License. 
+import warnings from copy import deepcopy import fiddle as fdl import pytorch_lightning as pl +from pytorch_lightning.loops import _TrainingEpochLoop +from pytorch_lightning.loops.fetchers import _DataFetcher from typing_extensions import Self from nemo.lightning.fabric.conversion import to_fabric @@ -23,8 +26,40 @@ from nemo.lightning.io.mixin import IOMixin, serialization, track_io -class Trainer(pl.Trainer, IOMixin): +class NoValOnRestartTrainingLoop(_TrainingEpochLoop): + """ + Extend the PTL Epoch loop to skip validation when restarting. + This happens when resuming a checkpoint that has already run validation, but loading restores + the training state before validation has run. + """ + + def _should_check_val_fx(self, data_fetcher) -> bool: + if self.skip_val_on_restart: + return False + return super()._should_check_val_fx(data_fetcher) + + def load_state_dict(self, state_dict: dict, prefix: str = "") -> None: + super().load_state_dict(state_dict, prefix) + + self.skip_val_on_restart = True + + def advance(self, data_fetcher: _DataFetcher) -> None: + super().advance(data_fetcher) + + self.skip_val_on_restart = False + +def configure_no_restart_validation_training_loop(trainer: pl.Trainer) -> None: + if not isinstance(trainer.fit_loop.epoch_loop, _TrainingEpochLoop): + warnings.warn("Detected custom epoch loop. Skipping no validation on restart support.", UserWarning) + return + + ## Pass trainer object to avoid trainer getting overwritten as None + loop = NoValOnRestartTrainingLoop(trainer, trainer.min_steps, trainer.max_steps) + trainer.fit_loop.epoch_loop = loop + + +class Trainer(pl.Trainer, IOMixin): def add_io(self, obj): """Recurse to the leaves of a container and add io functionality to non-serializable leaves""" if isinstance(obj, (dict, list)): diff --git a/requirements/requirements_nlp.txt b/requirements/requirements_nlp.txt index 16b6c574d2fa..6a86dacbfefb 100644 --- a/requirements/requirements_nlp.txt +++ b/requirements/requirements_nlp.txt @@ -14,7 +14,7 @@ matplotlib>=3.3.2 #megatron_core>0.6.0 # add back once mcore on pypi is compatible again nltk>=3.6.5 numpy<2 # tensorstore has an implicit compiled dependency on numpy<2 -opencc<1.1.7 +opencc pangu prettytable rapidfuzz diff --git a/scripts/checkpoint_converters/convert_nemo1_to_nemo2.py b/scripts/checkpoint_converters/convert_nemo1_to_nemo2.py index 1d69c1aec5eb..12e56e9f1793 100644 --- a/scripts/checkpoint_converters/convert_nemo1_to_nemo2.py +++ b/scripts/checkpoint_converters/convert_nemo1_to_nemo2.py @@ -23,7 +23,8 @@ --output_path=your_output_dir \ --model_id=meta-llama/Meta-Llama-3-8B -b. Convert a model weight directory. The checkpoint should be similar to `model_weights` subdir after extracting the .nemo file. +b. Convert a model weight directory. + The checkpoint should be similar to `model_weights` subdir after extracting the .nemo file. Please also provide tokenizer_library and tokenizer_path when loading from weight directory. 
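A minimal usage sketch for the newly exported `configure_no_restart_validation_training_loop` helper; the trainer settings are placeholders, and the commented `fit` call assumes a model and datamodule defined elsewhere:

    from nemo.lightning import Trainer, configure_no_restart_validation_training_loop

    trainer = Trainer(
        max_steps=100,          # hypothetical settings
        val_check_interval=10,
        accelerator="gpu",
        devices=1,
    )

    # Swap in NoValOnRestartTrainingLoop so validation is not re-run immediately
    # after resuming from a checkpoint. If a custom epoch loop is already installed,
    # the helper only warns and leaves the trainer unchanged.
    configure_no_restart_validation_training_loop(trainer)

    # trainer.fit(model, datamodule=datamodule)  # continue training as usual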
python /opt/NeMo/scripts/checkpoint_converters/convert_nemo1_to_nemo2.py \ --input_path=nemotron3-8b-extracted/model_weights \ @@ -52,8 +53,8 @@ from nemo.collections.nlp.modules.common.tokenizer_utils import get_nmt_tokenizer from nemo.collections.nlp.parts.nlp_overrides import NLPSaveRestoreConnector from nemo.lightning import MegatronStrategy, Trainer, _strategy_lib -from nemo.lightning.ckpt_utils import ckpt_to_context_subdir, ckpt_to_weights_subdir -from nemo.lightning.io.pl import TrainerContext +from nemo.lightning.ckpt_utils import ckpt_to_context_subdir +from nemo.lightning.io.pl import TrainerContext, ckpt_to_weights_subdir from nemo.utils import logging MODEL_CONFIG_MAPPING = { @@ -66,22 +67,29 @@ "mistralai/Mixtral-8x22B-v0.1": (llm.MixtralModel, llm.MixtralConfig8x22B), "mistralai/Mistral-7B-v0.1": (llm.MistralModel, llm.MistralConfig7B), "nvidia/nemotron-3-8b-base-4k": (llm.NemotronModel, llm.Nemotron3Config8B), - "nemotron4-22b": (llm.NemotronModel, llm.Nemotron4Config22B), + "nemotron4-22b": (llm.NemotronModel, llm.Nemotron3Config22B), "nemotron4-15b": (llm.NemotronModel, llm.Nemotron4Config15B), "nemotron4-340b": (llm.NemotronModel, llm.Nemotron4Config340B), } def get_args(): + """ + Parse the command line arguments. + """ parser = ArgumentParser( - description="Script to convert NeMo 1.0 checkpoints to NeMo 2.0 format. This script may download from Hugging Face, make sure you have access to gate repo and have logged into Hugging Face (e.g. huggingface-cli login)" + description="""Script to convert NeMo 1.0 checkpoints to NeMo 2.0 format. + This script may download from Hugging Face, make sure you have + access to gate repo and have logged into Hugging Face (e.g. huggingface-cli login)""" ) parser.add_argument( "--input_path", type=str, default=None, required=True, - help="Path to NeMo 1.0 checkpoints. Could be .nemo file, or `model_weights` directory after untar the .nemo. Please also provide tokenizer_library and tokenizer_path if you pass in `model_weights` directory.", + help="""Path to NeMo 1.0 checkpoints. Could be .nemo file, or `model_weights` directory a + fter untar the .nemo. Please also provide tokenizer_library and tokenizer_path if you pass + in `model_weights` directory.""", ) parser.add_argument( "--output_path", type=str, default=None, required=True, help="Path to output NeMo 2.0 directory." @@ -94,7 +102,8 @@ def get_args(): type=str, default=None, required=False, - help="Path to tokenizer. If not provided, will 1. try instantiate from nemo1 config 2. pull AutoTokenizer from Hugging Face according to model_id if 1 fails", + help="""Path to tokenizer. If not provided, will 1. try instantiate from nemo1 config + 2. pull AutoTokenizer from Hugging Face according to model_id if 1 fails""", ) parser.add_argument( "--tokenizer_library", @@ -108,6 +117,12 @@ def get_args(): def get_nemo2_model(model_id, tokenizer) -> llm.GPTModel: + """ + Get NeMo 2.0 model class from model_id and tokenizer. Use bf16 for NeMo 1.0 ckpts. + + Returns: + llm.GPTModel: NeMo 2.0 model instance + """ if model_id not in MODEL_CONFIG_MAPPING: valid_ids = "\n- ".join([""] + list(MODEL_CONFIG_MAPPING.keys())) @@ -118,6 +133,13 @@ def get_nemo2_model(model_id, tokenizer) -> llm.GPTModel: def get_tokenizer(input_path: Path, tokenizer_tmp_dir: Path) -> AutoTokenizer: + """ + Get tokenizer from input .nemo file, or args.tokenizer_path, or Hugging Face. + Only SentencePiece and Hugging Face tokenizers are supported. 
+ + Returns: + AutoTokenizer: tokenizer instance + """ if not input_path.is_dir(): # if .nemo tar with tempfile.TemporaryDirectory() as tmp_dir: # we want to clean up this tmp dir NLPSaveRestoreConnector._unpack_nemo_file(input_path, tmp_dir) @@ -134,7 +156,7 @@ def get_tokenizer(input_path: Path, tokenizer_tmp_dir: Path) -> AutoTokenizer: tokenizer_lib = args.tokenizer_library or "sentencepiece" if args.tokenizer_library is None: logging.warning( - "You specified tokenizer_path but did not provide tokenizer_library, will default to sentencepiece" + "You specified tokenizer_path but did not provide tokenizer_library using default sentencepiece" ) tokenizer_model = args.tokenizer_path else: # no .nemo config, no tokenizer path specified, grab from HF, reload @@ -148,6 +170,9 @@ def get_tokenizer(input_path: Path, tokenizer_tmp_dir: Path) -> AutoTokenizer: def main() -> None: + """ + Main function to convert NeMo 1.0 checkpoint to NeMo 2.0 format. + """ tokenizer_tmp_dir = Path("/tmp/nemo_tokenizer") tokenizer_tmp_dir.mkdir(parents=True, exist_ok=True) tokenizer = get_tokenizer(Path(args.input_path), tokenizer_tmp_dir) @@ -196,7 +221,7 @@ def skip_fp8_load(x): logging.info(f"Saving checkpoint to {args.output_path}") model_ckpt['state_dict'] = {k.replace('model', 'module', 1): v for k, v in model_ckpt['state_dict'].items()} trainer.model.module.load_state_dict(model_ckpt['state_dict']) - trainer.save_checkpoint(ckpt_to_weights_subdir(args.output_path)) + trainer.save_checkpoint(ckpt_to_weights_subdir(args.output_path, is_saving=False)) if getattr(trainer.strategy, "async_save", False): trainer.strategy.checkpoint_io.maybe_finalize_save_checkpoint(blocking=True) diff --git a/tests/collections/asr/decoding/test_rnnt_decoding.py b/tests/collections/asr/decoding/test_rnnt_decoding.py index 82b5d00bede6..b5250ad5f144 100644 --- a/tests/collections/asr/decoding/test_rnnt_decoding.py +++ b/tests/collections/asr/decoding/test_rnnt_decoding.py @@ -22,8 +22,9 @@ from nemo.collections.asr.models import ASRModel from nemo.collections.asr.modules import RNNTDecoder, RNNTJoint from nemo.collections.asr.parts.mixins import mixins -from nemo.collections.asr.parts.submodules import rnnt_beam_decoding as beam_decode +from nemo.collections.asr.parts.submodules import rnnt_beam_decoding from nemo.collections.asr.parts.submodules import rnnt_greedy_decoding as greedy_decode +from nemo.collections.asr.parts.submodules import tdt_beam_decoding from nemo.collections.asr.parts.submodules.rnnt_decoding import RNNTBPEDecoding, RNNTDecoding, RNNTDecodingConfig from nemo.collections.asr.parts.utils import rnnt_utils from nemo.core.utils import numba_utils @@ -166,6 +167,39 @@ def check_subword_timestamps(hyp: rnnt_utils.Hypothesis, decoding: RNNTBPEDecodi assert len(hyp.timestep['segment']) == segments_count +def check_beam_decoding(test_data_dir, beam_config): + beam_size = beam_config.pop("beam_size", 1) + model, encoded, encoded_len = get_model_encoder_output(test_data_dir, 'nvidia/parakeet-tdt_ctc-110m') + + model_config = model.to_config_dict() + durations = list(model_config["model_defaults"]["tdt_durations"]) + + beam = tdt_beam_decoding.BeamTDTInfer( + model.decoder, + model.joint, + beam_size=beam_size, + return_best_hypothesis=False, + durations=durations, + **beam_config, + ) + + enc_out = encoded + enc_len = encoded_len + + with torch.no_grad(): + hyps: rnnt_utils.Hypothesis = beam(encoder_output=enc_out, encoded_lengths=enc_len)[0] + _, all_hyps = decode_text_from_nbest_hypotheses(hyps, model.decoding) + 
all_hyps = all_hyps[0] + + print("Beam search algorithm :", beam_config['search_type']) + for idx, hyp_ in enumerate(all_hyps): + print("Hyp index", idx + 1, "text :", hyp_.text) + + assert len(hyp_.timestep) > 0 + print("Timesteps", hyp_.timestep) + print() + + class TestRNNTDecoding: @pytest.mark.unit def test_constructor(self): @@ -312,10 +346,10 @@ def test_batched_greedy_decoding_preserve_alignments(self, test_data_dir, loop_l {"search_type": "maes", "maes_num_steps": 3, "maes_expansion_beta": 1, "beam_size": 2}, ], ) - def test_beam_decoding_preserve_alignments(self, test_data_dir, beam_config): + def test_rnnt_beam_decoding_preserve_alignments(self, test_data_dir, beam_config): beam_size = beam_config.pop("beam_size", 1) model, encoded, encoded_len = get_model_encoder_output(test_data_dir, 'stt_en_conformer_transducer_small') - beam = beam_decode.BeamRNNTInfer( + beam = rnnt_beam_decoding.BeamRNNTInfer( model.decoder, model.joint, beam_size=beam_size, @@ -442,3 +476,51 @@ def test_char_decoding_compute_timestamps(self, test_data_dir, decoding_strategy hyps, _ = decoding.rnnt_decoder_predictions_tensor(encoded, encoded_len, return_hypotheses=True) check_char_timestamps(hyps[0], decoding) + + @pytest.mark.skipif( + not NUMBA_RNNT_LOSS_AVAILABLE, + reason='RNNTLoss has not been compiled with appropriate numba version.', + ) + @pytest.mark.with_downloads + @pytest.mark.unit + @pytest.mark.parametrize( + "beam_config", + [ + { + "search_type": "default", + "beam_size": 2, + }, + {"search_type": "maes", "maes_num_steps": 2, "maes_expansion_beta": 2, "beam_size": 2}, + {"search_type": "maes", "maes_num_steps": 2, "maes_expansion_beta": 1, "beam_size": 4}, + ], + ) + def test_tdt_beam_decoding(self, test_data_dir, beam_config): + check_beam_decoding(test_data_dir, beam_config) + + @pytest.mark.skipif( + not NUMBA_RNNT_LOSS_AVAILABLE, + reason='RNNTLoss has not been compiled with appropriate numba version.', + ) + @pytest.mark.with_downloads + @pytest.mark.unit + @pytest.mark.parametrize( + "beam_config", + [ + { + "search_type": "maes", + "maes_num_steps": 2, + "maes_expansion_beta": 1, + "beam_size": 4, + "ngram_lm_alpha": 0.3, + }, + ], + ) + def test_tdt_beam_decoding_with_kenlm(self, test_data_dir, beam_config): + # skipping if kenlm is not installed + pytest.importorskip("kenlm", reason="Skipping test because 'kenlm' is not installed.") + + kenlm_model_path = os.path.join( + test_data_dir, "asr", "kenlm_ngram_lm", "parakeet-tdt_ctc-110m-libri-1024.kenlm.tmp.arpa" + ) + beam_config["ngram_lm_model"] = kenlm_model_path + check_beam_decoding(test_data_dir, beam_config) diff --git a/tests/collections/llm/bitexact/mixtral/run.sh b/tests/collections/llm/bitexact/mixtral/run.sh index 0fe9e331b18a..87bf7c382b99 100644 --- a/tests/collections/llm/bitexact/mixtral/run.sh +++ b/tests/collections/llm/bitexact/mixtral/run.sh @@ -43,4 +43,4 @@ python3 /workspace/tests/collections/llm/bitexact/mixtral/pretrain_mini_mixtral. 
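For reference, a standalone sketch of constructing the TDT beam-search decoder that these tests exercise; the checkpoint name mirrors the one used in `check_beam_decoding` above, while the beam settings and the commented KenLM arguments are illustrative:

    from nemo.collections.asr.models import ASRModel
    from nemo.collections.asr.parts.submodules import tdt_beam_decoding

    # Load a TDT model and read its duration set from the model config.
    model = ASRModel.from_pretrained("nvidia/parakeet-tdt_ctc-110m", map_location="cpu")
    durations = list(model.to_config_dict()["model_defaults"]["tdt_durations"])

    beam = tdt_beam_decoding.BeamTDTInfer(
        model.decoder,
        model.joint,
        beam_size=4,
        durations=durations,
        return_best_hypothesis=True,
        search_type="maes",            # "default" is covered by the tests as well
        maes_num_steps=2,
        maes_expansion_beta=1,
        # ngram_lm_model="/path/to/ngram_lm.arpa",  # optional shallow fusion (hypothetical path)
        # ngram_lm_alpha=0.3,
    )

    # hyps = beam(encoder_output=enc_out, encoded_lengths=enc_len)  # enc_out/enc_len from the encoder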
# Compare outputs python3 /workspace/tests/collections/llm/bitexact/mixtral/compare_ckpts.py \ - "$NEMO_OUTPUT_PATH/checkpoints/--None=0.0000-epoch=0/weights" "$MCORE_OUTPUT_PATH/iter_0000010/" + "$NEMO_OUTPUT_PATH/checkpoints/--None=0.0000-epoch=0-consumed_samples=20.0/weights" "$MCORE_OUTPUT_PATH/iter_0000010/" diff --git a/tests/collections/llm/megatron_mixtral_pretraining.py b/tests/collections/llm/megatron_mixtral_pretraining.py index b4c5b960e0a7..4123c7b37987 100644 --- a/tests/collections/llm/megatron_mixtral_pretraining.py +++ b/tests/collections/llm/megatron_mixtral_pretraining.py @@ -158,7 +158,7 @@ def main(args): ) # Confirm checkpoint directory structure - output_path = Path(args.experiment_dir) / "checkpoints/--None=0.0000-epoch=0/weights" + output_path = Path(args.experiment_dir) / "checkpoints/--None=0.0000-epoch=0-consumed_samples=8.0/weights" assert output_path.exists(), f"Expected {output_path} to exist" assert output_path.is_dir(), f"Expected {output_path} to be a directory" output_files = ['__0_0.distcp', '__0_1.distcp', 'common.pt', 'metadata.json', '.metadata'] diff --git a/tests/lightning/test_nemo_run.py b/tests/lightning/test_nemo_run.py index f91322116824..1371b9adaa8e 100644 --- a/tests/lightning/test_nemo_run.py +++ b/tests/lightning/test_nemo_run.py @@ -30,7 +30,12 @@ ("llama3_70b", "finetune_recipe", "llama3_70b_finetune"), ("llama3_70b_16k", "pretrain_recipe", "llama3_70b_16k_pretrain"), ("llama3_70b_64k", "pretrain_recipe", "llama3_70b_64k_pretrain"), + ("llama31_8b", "pretrain_recipe", "llama31_8b_pretrain"), + ("llama31_8b", "finetune_recipe", "llama31_8b_finetune"), + ("llama31_70b", "pretrain_recipe", "llama31_70b_pretrain"), + ("llama31_70b", "finetune_recipe", "llama31_70b_finetune"), ("llama31_405b", "pretrain_recipe", "llama31_405b_pretrain"), + ("llama31_405b", "finetune_recipe", "llama31_405b_finetune"), ("mistral_7b", "pretrain_recipe", "mistral_pretrain"), ("mistral_7b", "finetune_recipe", "mistral_finetune"), ("mixtral_8x7b", "pretrain_recipe", "mixtral_8x7b_pretrain"), diff --git a/tests/lightning/test_state_restoration.py b/tests/lightning/test_state_restoration.py index 44e0673a1a39..ccc0eed64d56 100644 --- a/tests/lightning/test_state_restoration.py +++ b/tests/lightning/test_state_restoration.py @@ -239,7 +239,7 @@ def run_resume_train(mbs, gbs, num_dev): resume=AutoResume( resume_if_exists=True, resume_ignore_no_checkpoint=False, - resume_from_path=f'{EXP_DIR}default/v1/checkpoints/default--None=0.0000-epoch=0/', + resume_from_path=f'{EXP_DIR}default/v1/checkpoints/default--None=0.0000-epoch=0-consumed_samples=20.0/', ), ) trainer._teardown() diff --git a/tutorials/llm/llama-3/README.rst b/tutorials/llm/llama-3/README.rst index bb6171e6f582..3bb1a0896b82 100755 --- a/tutorials/llm/llama-3/README.rst +++ b/tutorials/llm/llama-3/README.rst @@ -17,6 +17,6 @@ This repository contains jupyter notebook tutorials using NeMo Framework for Lla * - `Llama 3.1 Law-Domain LoRA Fine-Tuning and Deployment with NeMo Framework and NVIDIA NIM <./sdg-law-title-generation>`_ - `Law StackExchange `_ - Perform LoRA PEFT on Llama 3.1 8B Instruct using a synthetically augmented version of Law StackExchange with NeMo Framework, followed by deployment with NVIDIA NIM. As a pre-requisite, follow the tutorial for `data curation using NeMo Curator `__. 
- * - `Llama 3.1 WikiText Pruning and Distillation with NeMo Framework <./pruning-distillation>`_ + * - `Llama 3.1 Pruning and Distillation with NeMo Framework <./pruning-distillation>`_ - `WikiText-103-v1 `_ - - Perform pruning and distillation on Llama 3.1 8B Instruct using the WikiText-103-v1 dataset with NeMo Framework. + - Perform pruning and distillation on Llama 3.1 8B using the WikiText-103-v1 dataset with NeMo Framework. diff --git a/tutorials/llm/llama-3/pruning-distillation/01_data_preparation.ipynb b/tutorials/llm/llama-3/pruning-distillation/01_data_preparation.ipynb new file mode 100644 index 000000000000..1f84dd2719e6 --- /dev/null +++ b/tutorials/llm/llama-3/pruning-distillation/01_data_preparation.ipynb @@ -0,0 +1,102 @@ +{ + "cells": [ + { + "cell_type": "markdown", + "id": "ab9e2e97-7f10-4353-859e-693842bde465", + "metadata": {}, + "source": [ + "### Step 1: Prepare the dataset\n", + "\n", + "The dataset has to be preprocessed using the [preprocess_data_for_megatron.py](https://github.com/NVIDIA/NeMo/blob/main/scripts/nlp_language_modeling/preprocess_data_for_megatron.py) script included in the NeMo Framework. This step will also tokenize data using the `meta-llama/Meta-Llama-3.1-8B` tokenizer model to convert the data into a memory map format.\n", + "\n", + "> `NOTE:` In the block of code below, pass the paths to your train, test and validation data files." + ] + }, + { + "cell_type": "code", + "execution_count": null, + "id": "6505c00b-9eb4-4087-9e49-423f6228e690", + "metadata": { + "scrolled": true, + "tags": [] + }, + "outputs": [], + "source": [ + "!python /opt/NeMo/scripts/nlp_language_modeling/preprocess_data_for_megatron.py \\\n", + "--input=\"./wikitext-data/wikitext-train.jsonl\" \\\n", + "--tokenizer-library='huggingface' \\\n", + "--tokenizer-type='meta-llama/Meta-Llama-3.1-8B' \\\n", + "--output-prefix=wikitext_tokenized_train \\\n", + "--append-eod \\\n", + "--workers=32" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "id": "fb1aa80f-70bc-4dff-8b08-3bff48d9a1c3", + "metadata": { + "scrolled": true, + "tags": [] + }, + "outputs": [], + "source": [ + "!python /opt/NeMo/scripts/nlp_language_modeling/preprocess_data_for_megatron.py \\\n", + "--input=\"./wikitext-data/wikitext-test.jsonl\" \\\n", + "--tokenizer-library='huggingface' \\\n", + "--tokenizer-type='meta-llama/Meta-Llama-3.1-8B' \\\n", + "--output-prefix=wikitext_tokenized_test \\\n", + "--append-eod \\\n", + "--workers=32" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "id": "42bec54a-94f6-4c87-8e14-2726ef6c2625", + "metadata": { + "scrolled": true, + "tags": [] + }, + "outputs": [], + "source": [ + "!python /opt/NeMo/scripts/nlp_language_modeling/preprocess_data_for_megatron.py \\\n", + "--input=\"./wikitext-data/wikitext-val.jsonl\" \\\n", + "--tokenizer-library='huggingface' \\\n", + "--tokenizer-type='meta-llama/Meta-Llama-3.1-8B' \\\n", + "--output-prefix=wikitext_tokenized_val \\\n", + "--append-eod \\\n", + "--workers=32" + ] + }, + { + "cell_type": "markdown", + "id": "5d77ee8a-e0dc-44f7-b5e8-3b6025d979d7", + "metadata": {}, + "source": [ + "After running the above scripts, you will see the preprocesed `wikitext_tokenized_{train/val/test}_text_document.{idx/bin}`files. These output files will be used in the next step." 
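As a small, optional follow-up to the preprocessing cells above (a sketch only, using the output prefixes shown in this notebook), one might confirm that the memory-mapped `.bin`/`.idx` pairs were produced before moving on:

    import os

    # These prefixes are what the later notebooks pass as DATA_TRAIN / DATA_VAL / DATA_TEST.
    for prefix in ("wikitext_tokenized_train_text_document",
                   "wikitext_tokenized_test_text_document",
                   "wikitext_tokenized_val_text_document"):
        present = all(os.path.exists(prefix + ext) for ext in (".bin", ".idx"))
        print(prefix, "found" if present else "missing")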
+ ] + } + ], + "metadata": { + "kernelspec": { + "display_name": "Python 3 (ipykernel)", + "language": "python", + "name": "python3" + }, + "language_info": { + "codemirror_mode": { + "name": "ipython", + "version": 3 + }, + "file_extension": ".py", + "mimetype": "text/x-python", + "name": "python", + "nbconvert_exporter": "python", + "pygments_lexer": "ipython3", + "version": "3.10.12" + } + }, + "nbformat": 4, + "nbformat_minor": 5 +} diff --git a/tutorials/llm/llama-3/pruning-distillation/02_teacher_finetuning.ipynb b/tutorials/llm/llama-3/pruning-distillation/02_teacher_finetuning.ipynb new file mode 100644 index 000000000000..8d08793bbe9a --- /dev/null +++ b/tutorials/llm/llama-3/pruning-distillation/02_teacher_finetuning.ipynb @@ -0,0 +1,153 @@ +{ + "cells": [ + { + "cell_type": "markdown", + "id": "84b146ba-08b6-4adb-a858-8e4294c5e781", + "metadata": {}, + "source": [ + "\n", + "### Step 2: Finetune the teacher on the dataset\n", + "\n", + "NeMo framework includes a standard python script [megatron_gpt_pretraining.py](https://github.com/NVIDIA/NeMo/blob/main/examples/nlp/language_modeling/megatron_gpt_pretraining.py) for training a model. Once you have your model downloaded and the dataset ready, fine-tuning the teacher model with NeMo is essentially just running this script!\n", + "\n", + "We finetune the unpruned model on our dataset to correct the distribution shift across the original dataset the model was trained on. Per the [blog](https://developer.nvidia.com/blog/how-to-prune-and-distill-llama-3-1-8b-to-an-nvidia-llama-3-1-minitron-4b-model/) and [tech report](https://arxiv.org/pdf/2408.11796), experiments showed that, without correcting for the distribution shift, the teacher provides suboptimal guidance on the dataset when being distilled.\n", + "\n", + "For this demonstration, this training run is capped by `STEPS`, and validation is carried out every `VAL_INTERVAL` steps.\n", + "\n", + "> `NOTE:` In the block of code below, pass the paths to your pre-processed train, test and validation data files as well as path to the teacher .nemo model." 
+ ] + }, + { + "cell_type": "code", + "execution_count": null, + "id": "12007ac8-2fd5-4de8-8964-97821c2198c0", + "metadata": { + "scrolled": true, + "tags": [] + }, + "outputs": [], + "source": [ + "%%bash \n", + "\n", + "export CUDA_DEVICE_MAX_CONNECTIONS=1\n", + "\n", + "# Set path(s) if different:\n", + "\n", + "MODEL=\"/workspace/llama-3_1-8b-nemo_v1.0/llama3_1_8b.nemo\"\n", + "\n", + "# Can change these to accommodate resources:\n", + "\n", + "TENSOR_PARALLEL_SIZE=8\n", + "NODES=1\n", + "MICRO_BATCH_SIZE=4\n", + "\n", + "# Don't change the following:\n", + "\n", + "EXPERIMENT_DIR=\"distill_trainings\"\n", + "EXPERIMENT_NAME=\"megatron_llama_ft\"\n", + "\n", + "DATA_TRAIN='wikitext_tokenized_train_text_document'\n", + "DATA_VAL='wikitext_tokenized_test_text_document'\n", + "DATA_TEST='wikitext_tokenized_val_text_document'\n", + "\n", + "STEPS=30\n", + "GLOBAL_BATCH_SIZE=128\n", + "\n", + "LOG_INTERVAL=1\n", + "VAL_INTERVAL=10\n", + "NUM_VAL_BATCHES=5\n", + "\n", + "LR=1e-4\n", + "MIN_LR=1e-5\n", + "WARMUP_STEPS=2\n", + "\n", + "cmd=\"torchrun --nproc-per-node=${TENSOR_PARALLEL_SIZE}\"\n", + "\n", + "${cmd} /opt/NeMo/examples/nlp/language_modeling/megatron_gpt_pretraining.py \\\n", + " --config-path /opt/NeMo/examples/nlp/language_modeling/conf/ \\\n", + " --config-name megatron_llama_distill.yaml \\\n", + " \\\n", + " name=${EXPERIMENT_NAME} \\\n", + " \\\n", + " exp_manager.exp_dir=${EXPERIMENT_DIR} \\\n", + " exp_manager.checkpoint_callback_params.save_top_k=1 \\\n", + " exp_manager.checkpoint_callback_params.save_nemo_on_train_end=True \\\n", + " \\\n", + " trainer.max_steps=${STEPS} \\\n", + " trainer.log_every_n_steps=${LOG_INTERVAL} \\\n", + " trainer.val_check_interval=${VAL_INTERVAL} \\\n", + " trainer.limit_val_batches=${NUM_VAL_BATCHES} \\\n", + " +trainer.num_sanity_val_steps=0 \\\n", + " \\\n", + " trainer.precision=bf16 \\\n", + " trainer.devices=${TENSOR_PARALLEL_SIZE} \\\n", + " trainer.num_nodes=${NODES} \\\n", + " \\\n", + " \"model.data.data_prefix={train:[1.0,$DATA_TRAIN],validation:[$DATA_VAL],test:[$DATA_TEST]}\" \\\n", + " \\\n", + " model.restore_from_path=${MODEL} \\\n", + " +model.dist_ckpt_load_strictness=log_all \\\n", + " \\\n", + " ~model.tokenizer \\\n", + " +model.tokenizer='{library: huggingface, type: meta-llama/Meta-Llama-3.1-8B, use_fast: True}' \\\n", + " \\\n", + " model.tensor_model_parallel_size=${TENSOR_PARALLEL_SIZE} \\\n", + " model.sequence_parallel=True \\\n", + " model.micro_batch_size=${MICRO_BATCH_SIZE} \\\n", + " model.global_batch_size=${GLOBAL_BATCH_SIZE} \\\n", + " \\\n", + " model.encoder_seq_length=8192 \\\n", + " model.num_layers=32 \\\n", + " model.hidden_size=4096 \\\n", + " model.ffn_hidden_size=14336 \\\n", + " model.num_attention_heads=32 \\\n", + " model.hidden_dropout=0.0 \\\n", + " model.attention_dropout=0.0 \\\n", + " model.apply_query_key_layer_scaling=True \\\n", + " model.normalization='rmsnorm' \\\n", + " model.bias=False \\\n", + " model.activation='fast-swiglu' \\\n", + " model.position_embedding_type='rope' \\\n", + " model.share_embeddings_and_output_weights=False \\\n", + " model.num_query_groups=8 \\\n", + " ++model.scale_positional_embedding=True \\\n", + " ++model.rotary_base=500000.0 \\\n", + " \\\n", + " model.optim.name=distributed_fused_adam \\\n", + " model.optim.lr=${LR} \\\n", + " model.optim.sched.min_lr=${MIN_LR} \\\n", + " model.optim.sched.warmup_steps=${WARMUP_STEPS}" + ] + }, + { + "cell_type": "markdown", + "id": "3040a993-8423-475f-8bc6-d1dd1ce16a83", + "metadata": {}, + "source": [ + "This will 
create a finetuned teacher model named `megatron_llama_ft.nemo` in `./distill_trainings/megatron_llama_ft/checkpoints/`. We'll use this later.\n", + "> `NOTE:`This script takes at least 20 minutes to run (depending on GPU) and will generate the finetuned teacher model." + ] + } + ], + "metadata": { + "kernelspec": { + "display_name": "Python 3 (ipykernel)", + "language": "python", + "name": "python3" + }, + "language_info": { + "codemirror_mode": { + "name": "ipython", + "version": 3 + }, + "file_extension": ".py", + "mimetype": "text/x-python", + "name": "python", + "nbconvert_exporter": "python", + "pygments_lexer": "ipython3", + "version": "3.10.12" + } + }, + "nbformat": 4, + "nbformat_minor": 5 +} diff --git a/tutorials/llm/llama-3/pruning-distillation/03_a_depth_pruning.ipynb b/tutorials/llm/llama-3/pruning-distillation/03_a_depth_pruning.ipynb new file mode 100644 index 000000000000..a195c2f3a405 --- /dev/null +++ b/tutorials/llm/llama-3/pruning-distillation/03_a_depth_pruning.ipynb @@ -0,0 +1,77 @@ +{ + "cells": [ + { + "cell_type": "markdown", + "id": "8bc99d2f-9ac6-40c2-b072-12b6cb7b9aca", + "metadata": {}, + "source": [ + "### Step 3: Prune the finetuned-teacher model to create a student\n", + "In this step, we will explore two methods to prune the finetuned teacher model. Refer to the ``NOTE`` in the **_step-by-step instructions_** section of [introduction.ipynb](./introduction.ipynb) to decide which pruning techniques you would like to explore.\n", + "\n", + "In the first method, depth-pruning, we trim the layers of the model." + ] + }, + { + "cell_type": "markdown", + "id": "72fa494e-6268-4044-a1d6-c0518d450cfd", + "metadata": {}, + "source": [ + "#### Step 3.a.: Using depth-pruning \n", + "To depth-prune, we will trim the last 16 layers in the finetined teacher model. For depth-pruning, we would be using the [megatron_gpt_drop_layers](https://github.com/NVIDIA/NeMo/blob/main/examples/nlp/language_modeling/megatron_gpt_drop_layers.py) script. \n", + "\n", + "Per the [blog](https://developer.nvidia.com/blog/how-to-prune-and-distill-llama-3-1-8b-to-an-nvidia-llama-3-1-minitron-4b-model/) and [tech report](https://arxiv.org/pdf/2408.11796), removing contiguous layers from the second last block (layers 16 to 31 continuously) yields the best overall results. \n", + "\n", + "> `NOTE:` In the block of code below, pass the paths to your finetuned teacher .nemo model." + ] + }, + { + "cell_type": "code", + "execution_count": null, + "id": "60cae073-a192-4d47-b220-b09736d39a93", + "metadata": { + "scrolled": true, + "tags": [] + }, + "outputs": [], + "source": [ + "!python -m torch.distributed.launch --nproc_per_node=8 \\\n", + " /opt/NeMo/examples/nlp/language_modeling/megatron_gpt_drop_layers.py \\\n", + " --path_to_nemo \"./distill_trainings/megatron_llama_ft/checkpoints/megatron_llama_ft.nemo\" \\\n", + " --path_to_save \"/workspace/4b_depth_pruned_model.nemo\" \\\n", + " --tensor_model_parallel_size 8 \\\n", + " --pipeline_model_parallel_size 1 \\\n", + " --gpus_per_node 8 \\\n", + " --drop_layers 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31" + ] + }, + { + "cell_type": "markdown", + "id": "375f298a-0363-4f44-b40c-2c8e9bab7d76", + "metadata": {}, + "source": [ + "Running this script will save the depth-pruned model `4b_depth_pruned_model.nemo` to your workspace." 
+ ] + } + ], + "metadata": { + "kernelspec": { + "display_name": "Python 3 (ipykernel)", + "language": "python", + "name": "python3" + }, + "language_info": { + "codemirror_mode": { + "name": "ipython", + "version": 3 + }, + "file_extension": ".py", + "mimetype": "text/x-python", + "name": "python", + "nbconvert_exporter": "python", + "pygments_lexer": "ipython3", + "version": "3.10.12" + } + }, + "nbformat": 4, + "nbformat_minor": 5 +} diff --git a/tutorials/llm/llama-3/pruning-distillation/03_b_width_pruning.ipynb b/tutorials/llm/llama-3/pruning-distillation/03_b_width_pruning.ipynb new file mode 100644 index 000000000000..7d91d36cbb32 --- /dev/null +++ b/tutorials/llm/llama-3/pruning-distillation/03_b_width_pruning.ipynb @@ -0,0 +1,92 @@ +{ + "cells": [ + { + "cell_type": "markdown", + "id": "8bc99d2f-9ac6-40c2-b072-12b6cb7b9aca", + "metadata": {}, + "source": [ + "### Step 3: Prune the finetuned-teacher model to create a student\n", + "In the second method, we will width-prune. In width-pruning, we trim the neurons, attention heads and embedding channels. \n", + "\n", + "Refer to the ``NOTE`` in the **_step-by-step instructions_** section of [introduction.ipynb](./introduction.ipynb) to decide which pruning techniques you would like to explore." + ] + }, + { + "cell_type": "markdown", + "id": "9207ed14-2f37-4712-88f3-543a128663ac", + "metadata": { + "tags": [] + }, + "source": [ + "#### Step 3.b.: Using width-pruning\n", + "To width-prune the model, we do the following:\n", + "- prune (trim) the MLP intermediate dimension from 14336 to 9216.\n", + "- prune the hidden size from 4096 to 3072.\n", + "- and retrain the attention headcount and number of layers\n", + "\n", + "For width-pruning we will use the [megatron_gpt_prune.py](https://github.com/NVIDIA/NeMo/blob/main/examples/nlp/language_modeling/megatron_gpt_prune.py) script in the NeMo Framework. To see the detailed list of parameters for width-pruning, you can view the [megatron_gpt_prune.yaml](https://github.com/NVIDIA/NeMo/blob/main/examples/nlp/language_modeling/conf/megatron_gpt_prune.yaml) file.\n", + "\n", + "We use the above parameters to get a competitive model for this demonstration. You can use other strategies or parameters from the [blog](https://developer.nvidia.com/blog/how-to-prune-and-distill-llama-3-1-8b-to-an-nvidia-llama-3-1-minitron-4b-model/) or the [tech report](https://arxiv.org/pdf/2408.11796) for your experiments. \n", + "\n", + "> `NOTE:` In the block of code below, pass the paths to your finetuned teacher .nemo model.\n", + "\n", + "> `TIP:` You can increase the ``batch_size`` (upto 1024) to speed up the width-pruning script execution." 
+ ] + }, + { + "cell_type": "code", + "execution_count": null, + "id": "571d1483-dd4c-403e-b321-293342e7a62a", + "metadata": { + "scrolled": true, + "tags": [] + }, + "outputs": [], + "source": [ + "!torchrun --nproc-per-node=8 /opt/NeMo/examples/nlp/language_modeling/megatron_gpt_prune.py \\\n", + " model.restore_from_path=\"./distill_trainings/megatron_llama_ft/checkpoints/megatron_llama_ft.nemo\" \\\n", + " model.tensor_model_parallel_size=1 \\\n", + " model.pipeline_model_parallel_size=8 \\\n", + " +model.dist_ckpt_load_strictness=log_all \\\n", + " inference.batch_size=64 \\\n", + " trainer.num_nodes=1 \\\n", + " trainer.precision=bf16 \\\n", + " trainer.devices=8 \\\n", + " prune.ffn_hidden_size=9216 \\\n", + " prune.num_attention_heads=null \\\n", + " prune.num_query_groups=null \\\n", + " prune.hidden_size=3072 \\\n", + " export.save_path=\"/workspace/4b_width_pruned_model.nemo\"" + ] + }, + { + "cell_type": "markdown", + "id": "e9fb0977-5c02-4ecc-b602-54d74b2e2184", + "metadata": {}, + "source": [ + "Running this script will save the width-pruned model `4b_width_pruned_model.nemo` to your workspace." + ] + } + ], + "metadata": { + "kernelspec": { + "display_name": "Python 3 (ipykernel)", + "language": "python", + "name": "python3" + }, + "language_info": { + "codemirror_mode": { + "name": "ipython", + "version": 3 + }, + "file_extension": ".py", + "mimetype": "text/x-python", + "name": "python", + "nbconvert_exporter": "python", + "pygments_lexer": "ipython3", + "version": "3.10.12" + } + }, + "nbformat": 4, + "nbformat_minor": 5 +} diff --git a/tutorials/llm/llama-3/pruning-distillation/04_a_distilling_depth_pruned_student.ipynb b/tutorials/llm/llama-3/pruning-distillation/04_a_distilling_depth_pruned_student.ipynb new file mode 100644 index 000000000000..ccbe1cbf394b --- /dev/null +++ b/tutorials/llm/llama-3/pruning-distillation/04_a_distilling_depth_pruned_student.ipynb @@ -0,0 +1,136 @@ +{ + "cells": [ + { + "cell_type": "markdown", + "id": "09d30e35-8e9d-4d2e-bd14-738c627a3963", + "metadata": {}, + "source": [ + "### Step 4: Distill knowledge from teacher into student\n", + "Distillation of a model with NeMo Framework is also possible using a python script: [megatron_gpt_distillation.py](https://github.com/NVIDIA/NeMo/blob/main/examples/nlp/language_modeling/megatron_gpt_distillation.py). In this notebook, we will explore distillation with the depth-pruned model as the `STUDENT` model. \n", + "\n", + "For this demonstration, the `TEACHER` would be the finetuned teacher model `megatron_llama_ft.nemo` and the `STUDENT` model would be the pruned 4B model. This training run is capped by `STEPS`, and validation is carried out every `VAL_INTERVAL` steps." + ] + }, + { + "cell_type": "markdown", + "id": "c33cf641-0d27-417f-b3ee-c06701698184", + "metadata": {}, + "source": [ + "#### Step 4.a.: Using depth-pruned student\n", + "While distilling knowledge from the teacher to depth-pruned model, the `STUDENT` model would be `4b_depth_pruned_model.nemo` as produced by the [depth-pruning](./03_a_depth_pruning.ipynb) notebook. This training run is capped by `STEPS`, and validation is carried out every `VAL_INTERVAL` steps.\n", + "\n", + "> `NOTE:` In the block of code below, pass the paths to your pre-processed train, test and validation data files as well as path to the teacher and student .nemo models." 
+ ] + }, + { + "cell_type": "code", + "execution_count": null, + "id": "5d23a01e-4912-47cb-bf21-b4fd72007ec1", + "metadata": { + "scrolled": true, + "tags": [] + }, + "outputs": [], + "source": [ + "%%bash \n", + "\n", + "export CUDA_DEVICE_MAX_CONNECTIONS=1\n", + "\n", + "# Can change these to accommodate resources:\n", + "\n", + "TENSOR_PARALLEL_SIZE=8\n", + "NODES=1\n", + "MICRO_BATCH_SIZE=4\n", + "\n", + "# Don't change the following:\n", + "\n", + "EXPERIMENT_DIR=\"distill_trainings\"\n", + "EXPERIMENT_NAME=\"megatron_llama_distill_depth_pruned_student\"\n", + "\n", + "TEACHER=\"${EXPERIMENT_DIR}/megatron_llama_ft/checkpoints/megatron_llama_ft.nemo\"\n", + "STUDENT=\"/workspace/4b_depth_pruned_model.nemo\"\n", + "\n", + "FINAL_MODEL_PATH=\"${EXPERIMENT_DIR}/${EXPERIMENT_NAME}/checkpoints/depth_pruned_distilled_4b_model.nemo\"\n", + "\n", + "DATA_TRAIN='wikitext_tokenized_train_text_document'\n", + "DATA_VAL='wikitext_tokenized_test_text_document'\n", + "DATA_TEST='wikitext_tokenized_val_text_document'\n", + "\n", + "STEPS=30\n", + "GLOBAL_BATCH_SIZE=128\n", + "\n", + "LOG_INTERVAL=1\n", + "VAL_INTERVAL=10\n", + "NUM_VAL_BATCHES=5\n", + "\n", + "LR=1e-4\n", + "MIN_LR=1e-5\n", + "WARMUP_STEPS=2\n", + "\n", + "cmd=\"torchrun --nproc-per-node=${TENSOR_PARALLEL_SIZE}\"\n", + "\n", + "${cmd} /opt/NeMo/examples/nlp/language_modeling/megatron_gpt_distillation.py \\\n", + " name=${EXPERIMENT_NAME} \\\n", + " \\\n", + " exp_manager.exp_dir=${EXPERIMENT_DIR} \\\n", + " exp_manager.checkpoint_callback_params.save_top_k=1 \\\n", + " \\\n", + " trainer.max_steps=${STEPS} \\\n", + " trainer.log_every_n_steps=${LOG_INTERVAL} \\\n", + " trainer.val_check_interval=${VAL_INTERVAL} \\\n", + " trainer.limit_val_batches=${NUM_VAL_BATCHES} \\\n", + " +trainer.num_sanity_val_steps=0 \\\n", + " \\\n", + " trainer.precision=bf16 \\\n", + " trainer.devices=${TENSOR_PARALLEL_SIZE} \\\n", + " trainer.num_nodes=${NODES} \\\n", + " \\\n", + " \"model.data.data_prefix={train:[1.0,$DATA_TRAIN],validation:[$DATA_VAL],test:[$DATA_TEST]}\" \\\n", + " \\\n", + " model.restore_from_path=${STUDENT} \\\n", + " model.kd_teacher_restore_from_path=${TEACHER} \\\n", + " model.nemo_path=${FINAL_MODEL_PATH} \\\n", + " \\\n", + " model.tensor_model_parallel_size=${TENSOR_PARALLEL_SIZE} \\\n", + " model.sequence_parallel=True \\\n", + " model.micro_batch_size=${MICRO_BATCH_SIZE} \\\n", + " model.global_batch_size=${GLOBAL_BATCH_SIZE} \\\n", + " \\\n", + " model.optim.name=distributed_fused_adam \\\n", + " model.optim.lr=${LR} \\\n", + " model.optim.sched.min_lr=${MIN_LR} \\\n", + " model.optim.sched.warmup_steps=${WARMUP_STEPS}" + ] + }, + { + "cell_type": "markdown", + "id": "42d910d9-14dd-44ba-bf2c-0064737c70fa", + "metadata": {}, + "source": [ + "This will create the final distilled model named `depth_pruned_distilled_4b_model.nemo` in `./distill_trainings/megatron_llama_distill_depth_pruned_student/checkpoints`.\n", + "> `NOTE:`This script takes at least 35 minutes to run (depends on GPU) and generate the final distilled model." 
+ ] + } + ], + "metadata": { + "kernelspec": { + "display_name": "Python 3 (ipykernel)", + "language": "python", + "name": "python3" + }, + "language_info": { + "codemirror_mode": { + "name": "ipython", + "version": 3 + }, + "file_extension": ".py", + "mimetype": "text/x-python", + "name": "python", + "nbconvert_exporter": "python", + "pygments_lexer": "ipython3", + "version": "3.10.12" + } + }, + "nbformat": 4, + "nbformat_minor": 5 +} diff --git a/tutorials/llm/llama-3/pruning-distillation/04_b_distilling_width_pruned_student.ipynb b/tutorials/llm/llama-3/pruning-distillation/04_b_distilling_width_pruned_student.ipynb new file mode 100644 index 000000000000..48e81c96cdcf --- /dev/null +++ b/tutorials/llm/llama-3/pruning-distillation/04_b_distilling_width_pruned_student.ipynb @@ -0,0 +1,138 @@ +{ + "cells": [ + { + "cell_type": "markdown", + "id": "d5062f23-c604-479b-9a4e-69989598b131", + "metadata": {}, + "source": [ + "### Step 4: Distill knowledge from teacher into student\n", + "Distillation of a model with NeMo Framework is also possible using a python script: [megatron_gpt_distillation.py](https://github.com/NVIDIA/NeMo/blob/main/examples/nlp/language_modeling/megatron_gpt_distillation.py). \n", + "In this notebook, we will explore distillation with the width-pruned model as the `STUDENT` model.\n", + "\n", + "For this demonstration, the `TEACHER` would be the finetuned teacher model `megatron_llama_ft.nemo` and the `STUDENT` model would be the pruned 4B model. This training run is capped by `STEPS`, and validation is carried out every `VAL_INTERVAL` steps." + ] + }, + { + "cell_type": "markdown", + "id": "be7de691-dd1d-4719-9872-98501a22e3c9", + "metadata": {}, + "source": [ + "#### Step 4.b.: Using width-pruned student\n", + "While distilling knowledge from the teacher to width-pruned model, the `STUDENT` model would be `4b_width_pruned_model.nemo` as produced by the [width-pruning](./03_b_width_pruning.ipynb) notebook. This training run is capped by `STEPS`, and validation is carried out every `VAL_INTERVAL` steps.\n", + "\n", + "> `NOTE:` In the block of code below, pass the paths to your pre-processed train, test and validation data files as well as path to the teacher and student .nemo models." 
+ ] + }, + { + "cell_type": "code", + "execution_count": null, + "id": "0070b526-771a-4a8d-b0ba-ab218b382bd9", + "metadata": { + "scrolled": true, + "tags": [] + }, + "outputs": [], + "source": [ + "%%bash \n", + "\n", + "export CUDA_DEVICE_MAX_CONNECTIONS=1\n", + "\n", + "# Can change these to accommodate resources:\n", + "\n", + "TENSOR_PARALLEL_SIZE=8\n", + "NODES=1\n", + "MICRO_BATCH_SIZE=4\n", + "\n", + "# Don't change the following:\n", + "\n", + "EXPERIMENT_DIR=\"distill_trainings\"\n", + "EXPERIMENT_NAME=\"megatron_llama_distill_width_pruned_student\"\n", + "\n", + "TEACHER=\"${EXPERIMENT_DIR}/megatron_llama_ft/checkpoints/megatron_llama_ft.nemo\"\n", + "STUDENT=\"/workspace/4b_width_pruned_model.nemo\"\n", + "\n", + "FINAL_MODEL_PATH=\"${EXPERIMENT_DIR}/${EXPERIMENT_NAME}/checkpoints/width_pruned_distilled_4b_model.nemo\"\n", + "\n", + "DATA_TRAIN='wikitext_tokenized_train_text_document'\n", + "DATA_VAL='wikitext_tokenized_test_text_document'\n", + "DATA_TEST='wikitext_tokenized_val_text_document'\n", + "\n", + "STEPS=30\n", + "GLOBAL_BATCH_SIZE=128\n", + "\n", + "LOG_INTERVAL=1\n", + "VAL_INTERVAL=10\n", + "NUM_VAL_BATCHES=5\n", + "\n", + "LR=1e-4\n", + "MIN_LR=1e-5\n", + "WARMUP_STEPS=2\n", + "\n", + "cmd=\"torchrun --nproc-per-node=${TENSOR_PARALLEL_SIZE}\"\n", + "\n", + "${cmd} /opt/NeMo/examples/nlp/language_modeling/megatron_gpt_distillation.py \\\n", + " name=${EXPERIMENT_NAME} \\\n", + " \\\n", + " exp_manager.exp_dir=${EXPERIMENT_DIR} \\\n", + " exp_manager.checkpoint_callback_params.save_top_k=1 \\\n", + " \\\n", + " trainer.max_steps=${STEPS} \\\n", + " trainer.log_every_n_steps=${LOG_INTERVAL} \\\n", + " trainer.val_check_interval=${VAL_INTERVAL} \\\n", + " trainer.limit_val_batches=${NUM_VAL_BATCHES} \\\n", + " +trainer.num_sanity_val_steps=0 \\\n", + " \\\n", + " trainer.precision=bf16 \\\n", + " trainer.devices=${TENSOR_PARALLEL_SIZE} \\\n", + " trainer.num_nodes=${NODES} \\\n", + " \\\n", + " \"model.data.data_prefix={train:[1.0,$DATA_TRAIN],validation:[$DATA_VAL],test:[$DATA_TEST]}\" \\\n", + " \\\n", + " model.restore_from_path=${STUDENT} \\\n", + " model.kd_teacher_restore_from_path=${TEACHER} \\\n", + " model.nemo_path=${FINAL_MODEL_PATH} \\\n", + " \\\n", + " model.tensor_model_parallel_size=${TENSOR_PARALLEL_SIZE} \\\n", + " model.sequence_parallel=True \\\n", + " model.micro_batch_size=${MICRO_BATCH_SIZE} \\\n", + " model.global_batch_size=${GLOBAL_BATCH_SIZE} \\\n", + " \\\n", + " model.optim.name=distributed_fused_adam \\\n", + " model.optim.lr=${LR} \\\n", + " model.optim.sched.min_lr=${MIN_LR} \\\n", + " model.optim.sched.warmup_steps=${WARMUP_STEPS} \\\n", + " +model.dist_ckpt_load_strictness=log_all" + ] + }, + { + "cell_type": "markdown", + "id": "d9dbc377-e19a-49e0-b245-fa828cca415a", + "metadata": {}, + "source": [ + "This will create the final width-pruned distilled model named `width_pruned_distilled_4b_model.nemo` in `./distill_trainings/megatron_llama_distill_width_pruned_student/checkpoints`.\n", + "> `NOTE:`This script takes at least 20 minutes to run (depends on GPU) and generate the final distilled model." 
+ ] + } + ], + "metadata": { + "kernelspec": { + "display_name": "Python 3 (ipykernel)", + "language": "python", + "name": "python3" + }, + "language_info": { + "codemirror_mode": { + "name": "ipython", + "version": 3 + }, + "file_extension": ".py", + "mimetype": "text/x-python", + "name": "python", + "nbconvert_exporter": "python", + "pygments_lexer": "ipython3", + "version": "3.10.12" + } + }, + "nbformat": 4, + "nbformat_minor": 5 +} diff --git a/tutorials/llm/llama-3/pruning-distillation/05_display_results.ipynb b/tutorials/llm/llama-3/pruning-distillation/05_display_results.ipynb new file mode 100644 index 000000000000..0264cc288957 --- /dev/null +++ b/tutorials/llm/llama-3/pruning-distillation/05_display_results.ipynb @@ -0,0 +1,168 @@ +{ + "cells": [ + { + "cell_type": "markdown", + "id": "6c91263b-b312-4ab2-b13f-0ee4b6e8bd0f", + "metadata": {}, + "source": [ + "### Step 5: Display the validation loss\n", + "\n", + "Now that the results are in, let's visualize the validation loss of the two distilled models using the `tensorboard` library. \n", + "> `NOTE:` This notebook demonstrates the use of the teacher finetuning, pruning and the distillation script. These scripts should ideally be run on a multi-node cluster with a larger `GLOBAL_BATCH_SIZE` and `STEPS` to see improvement in the validation loss." + ] + }, + { + "cell_type": "markdown", + "id": "b5822d62-8131-4046-8c22-0bf0fce81df7", + "metadata": {}, + "source": [ + "#### Validation Loss using depth-pruned model as student in distillation script\n", + "Here is an image of the validation loss over 30 steps of running the training step in the distillation script when we distill the knowledge from the finetuned teacher model to the depth-pruned student." + ] + }, + { + "cell_type": "code", + "execution_count": null, + "id": "0a665fe1-df45-4126-8694-f182af113133", + "metadata": { + "tags": [] + }, + "outputs": [], + "source": [ + "%load_ext tensorboard\n", + "%tensorboard --logdir \"distill_trainings/megatron_llama_distill_depth_pruned_student/\" --port=6007" + ] + }, + { + "cell_type": "code", + "execution_count": 6, + "id": "db6fcf26-8ae8-40e1-875a-0a10bf85be81", + "metadata": { + "tags": [] + }, + "outputs": [ + { + "data": { + "text/html": [ + "
Validation Loss over 30 Training Steps with Depth-Pruned model as Student
" + ], + "text/plain": [ + "" + ] + }, + "metadata": {}, + "output_type": "display_data" + }, + { + "data": { + "text/html": [ + "" + ], + "text/plain": [ + "" + ] + }, + "metadata": {}, + "output_type": "display_data" + } + ], + "source": [ + "from IPython.display import Image, display, HTML\n", + "title = \"Validation Loss over 30 Training Steps with Depth-Pruned model as Student\"\n", + "display(HTML(f\"
{title}
\"))\n", + "display(Image(url=\"https://github.com/NVIDIA/NeMo/releases/download/r2.0.0rc1/val_loss_depth_pruned_student_distillation.png\", width=400))" + ] + }, + { + "cell_type": "markdown", + "id": "f10041ae-6533-47de-9f76-f97d4469c27a", + "metadata": {}, + "source": [ + "#### Validation Loss using width-pruned model as student in distillation script\n", + "Here is an image of the validation loss over 30 steps of running the training step in the distillation script when we distill the knowledge from the finetuned teacher model to the width-pruned student." + ] + }, + { + "cell_type": "code", + "execution_count": null, + "id": "7b0c3118-4987-4df3-88bd-fcffdb521c5d", + "metadata": { + "tags": [] + }, + "outputs": [], + "source": [ + "%load_ext tensorboard\n", + "%tensorboard --logdir \"distill_trainings/megatron_llama_distill_width_pruned_student/\" --port=6008" + ] + }, + { + "cell_type": "code", + "execution_count": 5, + "id": "ecd79583-f662-40c6-a690-9f4bb847de4e", + "metadata": { + "tags": [] + }, + "outputs": [ + { + "data": { + "text/html": [ + "
Validation Loss over 30 Training Steps with Width-Pruned model as Student
" + ], + "text/plain": [ + "" + ] + }, + "metadata": {}, + "output_type": "display_data" + }, + { + "data": { + "text/html": [ + "" + ], + "text/plain": [ + "" + ] + }, + "metadata": {}, + "output_type": "display_data" + } + ], + "source": [ + "from IPython.display import Image, display, HTML\n", + "title = \"Validation Loss over 30 Training Steps with Width-Pruned model as Student\"\n", + "display(HTML(f\"
{title}
\"))\n", + "display(Image(url=\"https://github.com/NVIDIA/NeMo/releases/download/r2.0.0rc1/val_loss_width_pruned_student_distillation.png\", width=400))" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "id": "7ab6ed6f-8bc3-4188-919f-7cee842635ed", + "metadata": {}, + "outputs": [], + "source": [] + } + ], + "metadata": { + "kernelspec": { + "display_name": "Python 3 (ipykernel)", + "language": "python", + "name": "python3" + }, + "language_info": { + "codemirror_mode": { + "name": "ipython", + "version": 3 + }, + "file_extension": ".py", + "mimetype": "text/x-python", + "name": "python", + "nbconvert_exporter": "python", + "pygments_lexer": "ipython3", + "version": "3.10.12" + } + }, + "nbformat": 4, + "nbformat_minor": 5 +} diff --git a/tutorials/llm/llama-3/pruning-distillation/README.rst b/tutorials/llm/llama-3/pruning-distillation/README.rst index 9d4207a5c968..34febcffa366 100644 --- a/tutorials/llm/llama-3/pruning-distillation/README.rst +++ b/tutorials/llm/llama-3/pruning-distillation/README.rst @@ -1,18 +1,26 @@ -Llama 3.1 WikiText Pruning and Distillation with NeMo Framework +Llama 3.1 Pruning and Distillation with NeMo Framework ======================================================================================= `Llama 3.1 `_ are open-source large language models by Meta that deliver state-of-the-art performance on popular industry benchmarks. They have been pretrained on over 15 trillion tokens, and support a 128K token context length. They are available in three sizes, 8B, 70B, and 405B, and each size has two variants—base pretrained and instruction tuned. `NVIDIA NeMo Framework `_ provides tools to perform teacher finetuning, pruning and distillation on Llama 3.1 to fit your use case. +`NVIDIA TensorRT Model Optimizer `_ is a library (referred to as **Model Optimizer**, or **ModelOpt**) comprising state-of-the-art model optimization techniques including `quantization `_, `sparsity `_, `distillation `_, and `pruning `_ to compress models. + `LLM Pruning and Distillation in Practice: The Minitron Approach `_ provides tools to perform teacher finetuning, pruning and distillation on Llama 3.1 as described in the `tech report `_. +`How to Prune and Distill Llama-3.1 8B to an NVIDIA Llama-3.1-Minitron 4B Model `_ provides practical and effective structured compression best practices for LLMs that combine depth, width, attention, and MLP pruning with knowledge distillation-based retraining. These strategies are presented in the `Compact Language Models via Pruning and Knowledge Distillation `_ paper. + +`Mistral-NeMo-Minitron 8B Model Delivers Unparalleled Accuracy `_ introduces the Mistral-NeMo-Minitron 8B, a state-of-the-art 8 billion parameter language model created by pruning and distilling the larger Mistral NeMo 12B model. + Objectives ---------- -This tutorial shows how to perform depth-pruning, teacher finetuning and distillation on **Llama 3.1 8B Instruct** using the `WikiText-103-v1 `_ dataset with NeMo Framework. The `WikiText-103-v1 `_ language modeling dataset is a collection of over 100 million tokens extracted from the set of verified Good and Featured articles on Wikipedia. For this demonstration, we will perform a light finetuning procedure on the ``Meta Llama 3.1 8B Instruct`` teacher model to generate a finetuned teacher model ``megatron_llama_ft.nemo`` needed for optimal distillation. This finetuned teacher model is then depth-pruned to create a trimmed model ``4b_trimmed_model.nemo``. 
These models will serve as a starting point for distillation to create a final distilled 4B model. +This tutorial shows how to perform depth-pruning, teacher finetuning and distillation on **Llama 3.1 8B** using the `WikiText-103-v1 `_ dataset with NeMo Framework. The `WikiText-103-v1 `_ language modeling dataset is a collection of over 100 million tokens extracted from the set of verified Good and Featured articles on Wikipedia. For this demonstration, we will perform teacher correction by running a light finetuning procedure on the ``Meta Llama 3.1 8B`` teacher model to generate a finetuned teacher model ``megatron_llama_ft.nemo`` needed for optimal distillation. This finetuned teacher model is then trimmed. There are two methods to prune a model: depth-pruning and width-pruning. We will be exploring both pruning techniques which will yield ``4b_depth_pruned_model.nemo`` and ``4b_width_pruned_model.nemo`` respectively. These models will serve as a starting point for distillation to create the final distilled 4B models. We are using models utilizing the ``meta-llama/Meta-Llama-3.1-8B`` tokenizer for this demonstration. +``NOTE:`` A subset of functions is being demonstrated in the notebooks. Some features like Neural Architecture Search (NAS) are unavailable but will be supported in future releases. + Requirements ------------- @@ -31,14 +39,16 @@ Create a pruned and distilled model with NeMo Framework For pruning and distilling the model, you will use the NeMo Framework which is available as a `docker container `_. +``NOTE:`` These notebooks use `NVIDIA TensorRT Model Optimizer `_ under the hood for pruning and distillation. + -1. Download the `Llama 3.1 8B Instruct .nemo `_ from NVIDIA NGC using the `NGC CLI `_. Generate the ``NGC_API_KEY`` following these `instructions `_. The following command saves the ``.nemo`` format model in a folder named ``llama-3_1-8b-instruct-nemo_v1.0`` in the current directory. You can specify another path using the ``-d`` option in the CLI tool. +1. Download the `Llama 3.1 8B .nemo `_ from NVIDIA NGC using the `NGC CLI `_. Generate the ``NGC_API_KEY`` following these `instructions `_. The following command saves the ``.nemo`` format model in a folder named ``llama-3_1-8b-nemo_v1.0`` in the current directory. You can specify another path using the ``-d`` option in the CLI tool. .. code:: bash - ngc registry model download-version "nvidia/nemo/llama-3_1-8b-instruct-nemo:1.0" + ngc registry model download-version "nvidia/nemo/llama-3_1-8b-nemo:1.0" -2. Run the container using the following command. It is assumed that you have the dataset, notebook(s), and the ``llama-3.1-8b-instruct`` model available in the current directory. If not, mount the appropriate folder to ``/workspace``. +2. Run the container using the following command. It is assumed that you have the dataset, notebook(s), and the ``llama3_1_8b.nemo`` model available in the current directory. If not, mount the appropriate folder to ``/workspace``. .. code:: bash @@ -63,17 +73,38 @@ For pruning and distilling the model, you will use the NeMo Framework which is a jupyter lab --ip 0.0.0.0 --port=8888 --allow-root -4. Then, navigate to `this notebook <./llama3-pruning-distillation-nemofw.ipynb>`_. +4. Then, navigate to `this notebook <./introduction.ipynb>`_ to get started. +This directory contains a list of notebooks which will go over all the steps to create a distilled 4B model. 
+ +:: + + <$pruning_distillation> + └── introduction.ipynb + └── 01_data_preparation.ipynb + └── 02_teacher_finetuning.ipynb + └── 03_a_depth_pruning.ipynb + └── 03_b_width_pruning.ipynb + └── 04_a_distilling_depth_pruned_student.ipynb + └── 04_b_distilling_width_pruned_student.ipynb + └── 05_display_results.ipynb + Results ------------------------------------------------------------------------------ -``NOTE:`` This notebook demonstrates the use of the teacher finetuning, pruning and the distillation script. These scripts should ideally be run on a multi-node cluster with a larger ``GLOBAL_BATCH_SIZE`` and ``STEPS`` to see improvement in the validation loss. +``NOTE:`` This notebook demonstrates the use of the teacher finetuning, pruning and the distillation scripts. These scripts should ideally be run on a multi-node cluster with a larger ``GLOBAL_BATCH_SIZE`` and ``STEPS`` to see improvement in the validation loss. + +Here are the validation loss plots over 30 steps of running the training step in the distillation script (at the end of the `notebook <./05_display_results.ipynb>`_). -Here is the validation loss over 30 steps of running the training step in the distillation script (at the end of the `notebook <./llama3-pruning-distillation-nemofw.ipynb>`_). +.. figure:: https://github.com/NVIDIA/NeMo/releases/download/r2.0.0rc1/val_loss_depth_pruned_student_distillation.png + :width: 400px + :alt: Diagram showing the validation loss over 30 steps of running the training step in the distillation script when using the depth-pruned model as the student + :align: center -.. figure:: https://github.com/NVIDIA/NeMo/releases/download/r2.0.0rc1/val_loss_distillation.png + Figure 1: Validation Loss Plot when using the depth-pruned model as the student + +.. figure:: https://github.com/NVIDIA/NeMo/releases/download/r2.0.0rc1/val_loss_width_pruned_student_distillation.png :width: 400px - :alt: Diagram showing the validation loss over 30 steps of running the training step in the distillation script + :alt: Diagram showing the validation loss over 30 steps of running the training step in the distillation script when using the width-pruned model as the student :align: center - Figure 1: Validation Loss Plot \ No newline at end of file + Figure 2: Validation Loss Plot when using the width-pruned model as the student \ No newline at end of file diff --git a/tutorials/llm/llama-3/pruning-distillation/introduction.ipynb b/tutorials/llm/llama-3/pruning-distillation/introduction.ipynb new file mode 100644 index 000000000000..1a3efc9f5f1e --- /dev/null +++ b/tutorials/llm/llama-3/pruning-distillation/introduction.ipynb @@ -0,0 +1,190 @@ +{ + "cells": [ + { + "cell_type": "markdown", + "id": "411e6711-60fc-4488-8aa1-c6463cac8695", + "metadata": { + "tags": [] + }, + "source": [ + "# Pruning and Distillation of Llama 3.1 model with NeMo Framework" + ] + }, + { + "cell_type": "markdown", + "id": "03fd1cf4-c67a-4b8d-a5e5-46531be0f991", + "metadata": {}, + "source": [ + "This demonstration showcases performing pruning and distillation on **Llama 3.1-8B** with the [WikiText-103-v1](https://huggingface.co/datasets/Salesforce/wikitext/viewer/wikitext-103-v1) dataset using NeMo Framework. The [WikiText-103-v1](https://huggingface.co/datasets/Salesforce/wikitext/viewer/wikitext-103-v1) language modeling dataset is a collection of over 100 million tokens extracted from the set of verified 'Good' and 'Featured' articles on Wikipedia. 
\n", + "\n", + "For this demonstration, we will perform a light finetuning procedure on the `Meta Llama 3.1 8B` teacher model to generate a finetuned teacher model. This finetuned teacher model will then be trimmed. There are two methods to prune a model: depth-pruning and width-pruning. This workflow will showcase both methods which will yield `4b_depth_pruned_model.nemo` and `4b_width_pruned_model.nemo` respectively, that will serve as a starting point for distillation to the final 4B models. \n", + "\n", + "> We are using models utilizing the `meta-llama/Meta-Llama-3.1-8B` tokenizer for this demonstration.\n", + "\n", + "> `NOTE:` Ensure that you run this notebook inside the [NeMo Framework container](https://catalog.ngc.nvidia.com/orgs/nvidia/containers/nemo) which has all the required dependencies. \n", + "\n", + "**Instructions are available in the associated tutorial README to download the model and the container.**" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "id": "5a5026ce-39f1-43e3-93af-4c4f1e9da1f2", + "metadata": { + "scrolled": true, + "tags": [] + }, + "outputs": [], + "source": [ + "!pip install --upgrade ipywidgets notebook\n", + "!pip install datasets" + ] + }, + { + "cell_type": "markdown", + "id": "afe59b07-bb48-4913-90cc-bb416b48196c", + "metadata": { + "tags": [] + }, + "source": [ + "---\n", + "## Prerequisites\n", + "Ensure you have the following -\n", + "1. **Get the teacher model**: Download the `Meta Llama 3.1 8B .nemo` model. You must follow the instructions in the associated README to download and mount the folder to the NeMo FW container." + ] + }, + { + "cell_type": "code", + "execution_count": null, + "id": "b9d48b81-e978-4894-8ba4-4f183f698bb1", + "metadata": { + "tags": [] + }, + "outputs": [], + "source": [ + "!ls /workspace/llama-3_1-8b-nemo_v1.0/llama3_1_8b.nemo" + ] + }, + { + "cell_type": "markdown", + "id": "7129d44e-0536-4e62-bdbc-0f1ad44dc84a", + "metadata": {}, + "source": [ + "2. **Set the Hugging Face Access Token**: You can obtain this from your [Hugging Face account](https://huggingface.co/docs/hub/en/security-tokens). " + ] + }, + { + "cell_type": "code", + "execution_count": null, + "id": "481417ed-1456-4962-8f67-4350bde1aabd", + "metadata": { + "tags": [] + }, + "outputs": [], + "source": [ + "from huggingface_hub import login\n", + "login(token=\"\")" + ] + }, + { + "cell_type": "markdown", + "id": "245eda8d-c999-431e-9ebc-5c92c4f21f3b", + "metadata": {}, + "source": [ + "3. **Obtain the dataset**: Generate the `wikitext-{train/val/test}.jsonl` splits after loading the [WikiText-103-v1](https://huggingface.co/datasets/Salesforce/wikitext/viewer/wikitext-103-v1) dataset." 
+ ] + }, + { + "cell_type": "code", + "execution_count": null, + "id": "eaef2c7d-41f7-41ad-a76a-2d714e9c35de", + "metadata": { + "scrolled": true, + "tags": [] + }, + "outputs": [], + "source": [ + "# Split into train, test and val files\n", + "\n", + "import json\n", + "import os\n", + "from datasets import load_dataset\n", + "\n", + "# Load the WikiText-103 dataset\n", + "dataset = load_dataset(\"wikitext\", \"wikitext-103-v1\")\n", + "\n", + "# Define the destination folder\n", + "data_folder = 'wikitext-data'\n", + "os.makedirs(data_folder, exist_ok=True)\n", + "\n", + "# Define file paths and destination paths\n", + "file_paths = {\n", + " 'train': os.path.join(data_folder, 'wikitext-train.jsonl'),\n", + " 'validation': os.path.join(data_folder, 'wikitext-val.jsonl'),\n", + " 'test': os.path.join(data_folder, 'wikitext-test.jsonl')\n", + "}\n", + "\n", + "# Function to save dataset split to a JSONL file\n", + "def save_to_jsonl(file_path, data):\n", + " with open(file_path, 'w') as file:\n", + " for item in data:\n", + " file.write(json.dumps(item) + '\\n')\n", + "\n", + "# Define splits\n", + "splits = [\"train\", \"validation\", \"test\"]\n", + "\n", + "# Save splits to JSONL files and calculate their sizes\n", + "for split in splits:\n", + " if split in dataset:\n", + " save_to_jsonl(file_paths[split], dataset[split])\n", + " else:\n", + " print(f\"Split {split} not found in the dataset.\")\n" + ] + }, + { + "cell_type": "markdown", + "id": "2d0cc359-0598-40aa-af80-9503ecd4dac1", + "metadata": { + "tags": [] + }, + "source": [ + "---\n", + "## Step-by-step instructions\n", + "\n", + "This workflow is structured into seven notebooks:\n", + "1. [Prepare the dataset](./01_data_preparation.ipynb)\n", + "2. [Finetune the teacher on the dataset](./02_teacher_finetuning.ipynb)\n", + "3. Prune the finetuned-teacher model to create a student \n", + " - 3.a. [Using depth-pruning](./03_a_depth_pruning.ipynb)\n", + " - 3.b. [Using width-pruning](./03_b_width_pruning.ipynb)\n", + "4. Distill knowledge from teacher into student\n", + " - 4.a. [Using depth-pruned student](./04_a_distilling_depth_pruned_student.ipynb)\n", + " - 4.b. [Using width-pruned student](./04_b_distilling_width_pruned_student.ipynb)\n", + "5. [Display the validation loss](./05_display_results.ipynb)\n", + "\n", + "> `NOTE:` We are exploring two methods to prune the finetuned teacher model: [depth-pruning](./03_a_depth_pruning.ipynb) and [width-pruning](./03_b_width_pruning.ipynb). Per the [tech report](https://arxiv.org/pdf/2408.11796), we can observe that width-pruning generally outperforms depth-pruning so users can choose to perform either [depth-pruning](./03_a_depth_pruning.ipynb) or [width-pruning](./03_b_width_pruning.ipynb) or both methods." 
+ ] + } + ], + "metadata": { + "kernelspec": { + "display_name": "Python 3 (ipykernel)", + "language": "python", + "name": "python3" + }, + "language_info": { + "codemirror_mode": { + "name": "ipython", + "version": 3 + }, + "file_extension": ".py", + "mimetype": "text/x-python", + "name": "python", + "nbconvert_exporter": "python", + "pygments_lexer": "ipython3", + "version": "3.10.12" + } + }, + "nbformat": 4, + "nbformat_minor": 5 +} diff --git a/tutorials/llm/llama-3/pruning-distillation/llama3-pruning-distillation-nemofw.ipynb b/tutorials/llm/llama-3/pruning-distillation/llama3-pruning-distillation-nemofw.ipynb deleted file mode 100644 index 8b31ad4de018..000000000000 --- a/tutorials/llm/llama-3/pruning-distillation/llama3-pruning-distillation-nemofw.ipynb +++ /dev/null @@ -1,587 +0,0 @@ -{ - "cells": [ - { - "cell_type": "markdown", - "id": "363a6974-810c-41c5-84da-4751a92fb72b", - "metadata": { - "tags": [] - }, - "source": [ - "# Pruning and Distillation of Llama 3.1 model with NeMo Framework" - ] - }, - { - "cell_type": "markdown", - "id": "c6d4ed6d-8ecd-4647-bd0a-e48fec64c199", - "metadata": {}, - "source": [ - "This notebook showcases performing pruning and distillation on **Llama 3.1-8B-Instruct** with the [WikiText-103-v1](https://huggingface.co/datasets/Salesforce/wikitext/viewer/wikitext-103-v1) dataset using NeMo Framework. The [WikiText-103-v1](https://huggingface.co/datasets/Salesforce/wikitext/viewer/wikitext-103-v1) language modeling dataset is a collection of over 100 million tokens extracted from the set of verified Good and Featured articles on Wikipedia. \n", - "\n", - "For this demonstration, we will perform a light finetuning procedure on the `Meta Llama 3.1 8B Instruct` teacher model to generate a finetuned teacher model. This finetuned teacher model will then be trimmed to create a depth-pruned model `4b_trimmed_model.nemo` that will serve as a starting point for distillation to a final 4B model. \n", - "\n", - "> We are using models utilizing the `meta-llama/Meta-Llama-3.1-8B` tokenizer for this demonstration.\n", - "\n", - "> `NOTE:` Ensure that you run this notebook inside the [NeMo Framework container](https://catalog.ngc.nvidia.com/orgs/nvidia/containers/nemo) which has all the required dependencies. \n", - "\n", - "**Instructions are available in the associated tutorial README to download the model and the container.**" - ] - }, - { - "cell_type": "code", - "execution_count": null, - "id": "1d0dc714-5bbf-4266-805a-9841ff486c05", - "metadata": { - "scrolled": true, - "tags": [] - }, - "outputs": [], - "source": [ - "!pip install --upgrade ipywidgets notebook\n", - "!pip install datasets" - ] - }, - { - "cell_type": "markdown", - "id": "2658505d-7990-40a5-a269-866ddd8a0181", - "metadata": { - "tags": [] - }, - "source": [ - "---\n", - "## Prerequisites\n", - "Ensure you have the following -\n", - "1. **Get the teacher model**: Download the `Meta Llama 3.1 8B Instruct .nemo` model. You must follow the instructions in the associated README to download and mount the folder to the NeMo FW container." - ] - }, - { - "cell_type": "code", - "execution_count": null, - "id": "a30cfe8a-87a8-4511-be5f-e20d7fe558d4", - "metadata": {}, - "outputs": [], - "source": [ - "!ls /workspace/llama-3_1-8b-instruct-nemo_v1.0" - ] - }, - { - "cell_type": "markdown", - "id": "251a670e-9636-4807-bc98-a91c6137454d", - "metadata": {}, - "source": [ - "2. 
**Set the Hugging Face Access Token**: You can obtain this from your [Hugging Face account](https://huggingface.co/docs/hub/en/security-tokens). " - ] - }, - { - "cell_type": "code", - "execution_count": null, - "id": "47d7887d-b582-4a1e-81cd-fdc1be8d9afb", - "metadata": { - "tags": [] - }, - "outputs": [], - "source": [ - "from huggingface_hub import login\n", - "login(token=\"\")" - ] - }, - { - "cell_type": "markdown", - "id": "b5384e9a-6c40-4454-abe8-413ad9d5db96", - "metadata": {}, - "source": [ - "3. **Obtain the dataset**: Generate the `wikitext-{train/val/test}.jsonl` splits after loading the [WikiText-103-v1](https://huggingface.co/datasets/Salesforce/wikitext/viewer/wikitext-103-v1) dataset." - ] - }, - { - "cell_type": "code", - "execution_count": null, - "id": "b420bd44-3628-45e2-92e7-df38f72a658a", - "metadata": { - "scrolled": true, - "tags": [] - }, - "outputs": [], - "source": [ - "# Split into train, test and val files\n", - "\n", - "import json\n", - "import os\n", - "from datasets import load_dataset\n", - "\n", - "# Load the WikiText-103 dataset\n", - "dataset = load_dataset(\"wikitext\", \"wikitext-103-v1\")\n", - "\n", - "# Define the destination folder\n", - "data_folder = 'wikitext-data'\n", - "os.makedirs(data_folder, exist_ok=True)\n", - "\n", - "# Define file paths and destination paths\n", - "file_paths = {\n", - " 'train': os.path.join(data_folder, 'wikitext-train.jsonl'),\n", - " 'validation': os.path.join(data_folder, 'wikitext-val.jsonl'),\n", - " 'test': os.path.join(data_folder, 'wikitext-test.jsonl')\n", - "}\n", - "\n", - "# Function to save dataset split to a JSONL file\n", - "def save_to_jsonl(file_path, data):\n", - " with open(file_path, 'w') as file:\n", - " for item in data:\n", - " file.write(json.dumps(item) + '\\n')\n", - "\n", - "# Define splits\n", - "splits = [\"train\", \"validation\", \"test\"]\n", - "\n", - "# Save splits to JSONL files and calculate their sizes\n", - "for split in splits:\n", - " if split in dataset:\n", - " save_to_jsonl(file_paths[split], dataset[split])\n", - " else:\n", - " print(f\"Split {split} not found in the dataset.\")\n" - ] - }, - { - "cell_type": "markdown", - "id": "0185a0a9-904d-46de-a450-db4c84c4cde4", - "metadata": { - "tags": [] - }, - "source": [ - "---\n", - "## Step-by-step instructions\n", - "\n", - "This notebook is structured into five steps:\n", - "1. Prepare the dataset\n", - "2. Finetune the teacher on the dataset\n", - "3. Prune the finetuned-teacher model to create a student\n", - "3. Distill knowledge from teacher into student\n", - "4. Display the validation loss" - ] - }, - { - "cell_type": "markdown", - "id": "cf1d41ff-2cba-4efc-84e3-7d713df0cdb8", - "metadata": {}, - "source": [ - "### Step 1: Prepare the dataset\n", - "\n", - "The dataset has to be preprocessed using the [preprocess_data_for_megatron.py](https://github.com/NVIDIA/NeMo/blob/main/scripts/nlp_language_modeling/preprocess_data_for_megatron.py) script included in the NeMo Framework. This step will also tokenize data using the `meta-llama/Meta-Llama-3.1-8B` tokenizer model to convert the data into a memory map format.\n", - "\n", - "> `NOTE:` In the block of code below, pass the paths to your train, test and validation data files." 
- ] - }, - { - "cell_type": "code", - "execution_count": null, - "id": "2c49c1b8-2447-426c-9f24-bf5956aa2941", - "metadata": { - "scrolled": true, - "tags": [] - }, - "outputs": [], - "source": [ - "!python /opt/NeMo/scripts/nlp_language_modeling/preprocess_data_for_megatron.py \\\n", - "--input=\"./wikitext-data/wikitext-train.jsonl\" \\\n", - "--tokenizer-library='huggingface' \\\n", - "--tokenizer-type='meta-llama/Meta-Llama-3.1-8B' \\\n", - "--output-prefix=wikitext_tokenized_train \\\n", - "--append-eod \\\n", - "--workers=32" - ] - }, - { - "cell_type": "code", - "execution_count": null, - "id": "72d14fd7-702f-4b74-a6e5-af3a60eef3a9", - "metadata": { - "scrolled": true, - "tags": [] - }, - "outputs": [], - "source": [ - "!python /opt/NeMo/scripts/nlp_language_modeling/preprocess_data_for_megatron.py \\\n", - "--input=\"./wikitext-data/wikitext-test.jsonl\" \\\n", - "--tokenizer-library='huggingface' \\\n", - "--tokenizer-type='meta-llama/Meta-Llama-3.1-8B' \\\n", - "--output-prefix=wikitext_tokenized_test \\\n", - "--append-eod \\\n", - "--workers=32" - ] - }, - { - "cell_type": "code", - "execution_count": null, - "id": "1338a1ce-f0e2-4151-ad3d-d34db75ea1bd", - "metadata": { - "scrolled": true, - "tags": [] - }, - "outputs": [], - "source": [ - "!python /opt/NeMo/scripts/nlp_language_modeling/preprocess_data_for_megatron.py \\\n", - "--input=\"./wikitext-data/wikitext-val.jsonl\" \\\n", - "--tokenizer-library='huggingface' \\\n", - "--tokenizer-type='meta-llama/Meta-Llama-3.1-8B' \\\n", - "--output-prefix=wikitext_tokenized_val \\\n", - "--append-eod \\\n", - "--workers=32" - ] - }, - { - "cell_type": "markdown", - "id": "eb80e212-c343-4e51-a92d-184db43df011", - "metadata": {}, - "source": [ - "After running the above scripts, you will see the preprocesed `wikitext_tokenized_{train/val/test}_text_document.{idx/bin}`files. These output files will be used in the next step." - ] - }, - { - "cell_type": "markdown", - "id": "e9f30c0a-4315-4017-b014-add4291a3fde", - "metadata": {}, - "source": [ - "\n", - "### Step 2: Finetune the teacher on the dataset\n", - "\n", - "NeMo framework includes a standard python script [megatron_gpt_pretraining.py](https://github.com/NVIDIA/NeMo/blob/main/examples/nlp/language_modeling/megatron_gpt_pretraining.py) for training a model. Once you have your model downloaded and the dataset ready, fine-tuning the teacher model with NeMo is essentially just running this script!\n", - "\n", - "For this demonstration, this training run is capped by `STEPS`, and validation is carried out every `VAL_INTERVAL` steps.\n", - "\n", - "> `NOTE:` In the block of code below, pass the paths to your pre-processed train, test and validation data files as well as path to the teacher .nemo model." 
- ] - }, - { - "cell_type": "code", - "execution_count": null, - "id": "c31fd642-0304-43ed-9211-041dc36f22c3", - "metadata": { - "scrolled": true, - "tags": [] - }, - "outputs": [], - "source": [ - "%%bash \n", - "\n", - "export CUDA_DEVICE_MAX_CONNECTIONS=1\n", - "\n", - "\n", - "# Set path(s) if different:\n", - "\n", - "MODEL=\"/workspace/llama-3_1-8b-instruct-nemo_v1.0/llama3_1_8b_instruct.nemo\"\n", - "\n", - "# Can change these to accommodate resources:\n", - "\n", - "TENSOR_PARALLEL_SIZE=8\n", - "NODES=1\n", - "MICRO_BATCH_SIZE=4\n", - "\n", - "# Don't change the following:\n", - "\n", - "EXPERIMENT_DIR=\"distill_trainings\"\n", - "EXPERIMENT_NAME=\"megatron_llama_ft\"\n", - "\n", - "DATA_TRAIN='wikitext_tokenized_train_text_document'\n", - "DATA_VAL='wikitext_tokenized_test_text_document'\n", - "DATA_TEST='wikitext_tokenized_val_text_document'\n", - "\n", - "STEPS=30\n", - "GLOBAL_BATCH_SIZE=128\n", - "\n", - "LOG_INTERVAL=1\n", - "VAL_INTERVAL=10\n", - "NUM_VAL_BATCHES=5\n", - "\n", - "LR=1e-4\n", - "MIN_LR=1e-5\n", - "WARMUP_STEPS=2\n", - "\n", - "\n", - "cmd=\"torchrun --nproc-per-node=${TENSOR_PARALLEL_SIZE}\"\n", - "\n", - "${cmd} /opt/NeMo/examples/nlp/language_modeling/megatron_gpt_pretraining.py \\\n", - " --config-path /opt/NeMo/examples/nlp/language_modeling/conf/ \\\n", - " --config-name megatron_llama_distill.yaml \\\n", - " \\\n", - " name=${EXPERIMENT_NAME} \\\n", - " \\\n", - " exp_manager.exp_dir=${EXPERIMENT_DIR} \\\n", - " exp_manager.checkpoint_callback_params.save_top_k=1 \\\n", - " exp_manager.checkpoint_callback_params.save_nemo_on_train_end=True \\\n", - " \\\n", - " trainer.max_steps=${STEPS} \\\n", - " trainer.log_every_n_steps=${LOG_INTERVAL} \\\n", - " trainer.val_check_interval=${VAL_INTERVAL} \\\n", - " trainer.limit_val_batches=${NUM_VAL_BATCHES} \\\n", - " +trainer.num_sanity_val_steps=0 \\\n", - " \\\n", - " trainer.precision=bf16 \\\n", - " trainer.devices=${TENSOR_PARALLEL_SIZE} \\\n", - " trainer.num_nodes=${NODES} \\\n", - " \\\n", - " \"model.data.data_prefix={train:[1.0,$DATA_TRAIN],validation:[$DATA_VAL],test:[$DATA_TEST]}\" \\\n", - " \\\n", - " model.restore_from_path=${MODEL} \\\n", - " \\\n", - " ~model.tokenizer \\\n", - " +model.tokenizer='{library: huggingface, type: meta-llama/Meta-Llama-3.1-8B, use_fast: True}' \\\n", - " \\\n", - " model.tensor_model_parallel_size=${TENSOR_PARALLEL_SIZE} \\\n", - " model.sequence_parallel=True \\\n", - " model.micro_batch_size=${MICRO_BATCH_SIZE} \\\n", - " model.global_batch_size=${GLOBAL_BATCH_SIZE} \\\n", - " \\\n", - " model.encoder_seq_length=8192 \\\n", - " model.num_layers=32 \\\n", - " model.hidden_size=4096 \\\n", - " model.ffn_hidden_size=14336 \\\n", - " model.num_attention_heads=32 \\\n", - " model.hidden_dropout=0.0 \\\n", - " model.attention_dropout=0.0 \\\n", - " model.apply_query_key_layer_scaling=True \\\n", - " model.normalization='rmsnorm' \\\n", - " model.bias=False \\\n", - " model.activation='fast-swiglu' \\\n", - " model.position_embedding_type='rope' \\\n", - " model.share_embeddings_and_output_weights=False \\\n", - " model.num_query_groups=8 \\\n", - " ++model.scale_positional_embedding=True \\\n", - " ++model.rotary_base=500000.0 \\\n", - " \\\n", - " model.optim.name=distributed_fused_adam \\\n", - " model.optim.lr=${LR} \\\n", - " model.optim.sched.min_lr=${MIN_LR} \\\n", - " model.optim.sched.warmup_steps=${WARMUP_STEPS}" - ] - }, - { - "cell_type": "markdown", - "id": "8aaf604a-efc0-4908-9055-5cf3bb0a05ae", - "metadata": {}, - "source": [ - "This will create a finetuned 
teacher model named `megatron_llama_ft.nemo` in `./distill_trainings/megatron_llama_ft/checkpoints/`. We'll use this later.\n", - "> `NOTE:`This script takes at least 20 minutes to run (depending on GPU) and will generate the finetuned teacher model." - ] - }, - { - "cell_type": "markdown", - "id": "2709ccc0-bbb8-44ba-b00d-15b1dc5d60a7", - "metadata": {}, - "source": [ - "### Step 3: Prune the finetuned-teacher model to create a student\n", - "\n", - "The next step is to trim the last 16 layers in the finetined teacher model. In this notebook, we are using depth-pruning and would be using the [megatron_gpt_drop_layers](https://github.com/NVIDIA/NeMo/blob/main/examples/nlp/language_modeling/megatron_gpt_drop_layers.py) script. \n", - "> `NOTE:` In the block of code below, pass the paths to your finetuned teacher .nemo model." - ] - }, - { - "cell_type": "code", - "execution_count": null, - "id": "a9715a1b-7a23-437f-b5e1-feec8e6c68e0", - "metadata": { - "scrolled": true, - "tags": [] - }, - "outputs": [], - "source": [ - "!python -m torch.distributed.launch --nproc_per_node=8 \\\n", - " /opt/NeMo/examples/nlp/language_modeling/megatron_gpt_drop_layers.py \\\n", - " --path_to_nemo \"./distill_trainings/megatron_llama_ft/checkpoints/megatron_llama_ft.nemo\" \\\n", - " --path_to_save \"/workspace/4b_trimmed_model.nemo\" \\\n", - " --tensor_model_parallel_size 8 \\\n", - " --pipeline_model_parallel_size 1 \\\n", - " --gpus_per_node 8 \\\n", - " --drop_layers 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31" - ] - }, - { - "cell_type": "markdown", - "id": "1e9553db-9478-4074-9de1-1fa01a0e835c", - "metadata": {}, - "source": [ - "Running this script will save the depth-pruned model `4b_trimmed_model.nemo` to your workspace." - ] - }, - { - "cell_type": "markdown", - "id": "b8ada696-5d77-4113-9d15-a603113fdd58", - "metadata": {}, - "source": [ - "\n", - "### Step 4: Distill knowledge from teacher into student\n", - "\n", - "Distillation of a model with NeMo Framework is also possible using a python script: [megatron_gpt_distillation.py](https://github.com/NVIDIA/NeMo/blob/main/examples/nlp/language_modeling/megatron_gpt_distillation.py). \n", - "\n", - "For this demonstration, the `TEACHER` would be the finetuned teacher model `megatron_llama_ft.nemo` and the `STUDENT` model would be the pruned 4B model `4b_trimmed_model.nemo`. This training run is capped by `STEPS`, and validation is carried out every `VAL_INTERVAL` steps.\n", - "\n", - "> `NOTE:` In the block of code below, pass the paths to your pre-processed train, test and validation data files as well as path to the teacher and student .nemo models." 
- ] - }, - { - "cell_type": "code", - "execution_count": null, - "id": "61c0c69d-9401-4355-8725-78aa72eee8da", - "metadata": { - "scrolled": true, - "tags": [] - }, - "outputs": [], - "source": [ - "%%bash \n", - "\n", - "export CUDA_DEVICE_MAX_CONNECTIONS=1\n", - "\n", - "\n", - "# Can change these to accommodate resources:\n", - "\n", - "TENSOR_PARALLEL_SIZE=8\n", - "NODES=1\n", - "MICRO_BATCH_SIZE=4\n", - "\n", - "# Don't change the following:\n", - "\n", - "EXPERIMENT_DIR=\"distill_trainings\"\n", - "EXPERIMENT_NAME=\"megatron_llama_distill\"\n", - "\n", - "TEACHER=\"${EXPERIMENT_DIR}/megatron_llama_ft/checkpoints/megatron_llama_ft.nemo\"\n", - "STUDENT=\"/workspace/4b_trimmed_model.nemo\"\n", - "\n", - "FINAL_MODEL_PATH=\"${EXPERIMENT_DIR}/${EXPERIMENT_NAME}/checkpoints/distilled_4b_model.nemo\"\n", - "\n", - "DATA_TRAIN='wikitext_tokenized_train_text_document'\n", - "DATA_VAL='wikitext_tokenized_test_text_document'\n", - "DATA_TEST='wikitext_tokenized_val_text_document'\n", - "\n", - "STEPS=30\n", - "GLOBAL_BATCH_SIZE=128\n", - "\n", - "LOG_INTERVAL=1\n", - "VAL_INTERVAL=10\n", - "NUM_VAL_BATCHES=5\n", - "\n", - "LR=1e-4\n", - "MIN_LR=1e-5\n", - "WARMUP_STEPS=2\n", - "\n", - "\n", - "cmd=\"torchrun --nproc-per-node=${TENSOR_PARALLEL_SIZE}\"\n", - "\n", - "${cmd} /opt/NeMo/examples/nlp/language_modeling/megatron_gpt_distillation.py \\\n", - " name=${EXPERIMENT_NAME} \\\n", - " \\\n", - " exp_manager.exp_dir=${EXPERIMENT_DIR} \\\n", - " exp_manager.checkpoint_callback_params.save_top_k=1 \\\n", - " \\\n", - " trainer.max_steps=${STEPS} \\\n", - " trainer.log_every_n_steps=${LOG_INTERVAL} \\\n", - " trainer.val_check_interval=${VAL_INTERVAL} \\\n", - " trainer.limit_val_batches=${NUM_VAL_BATCHES} \\\n", - " +trainer.num_sanity_val_steps=0 \\\n", - " \\\n", - " trainer.precision=bf16 \\\n", - " trainer.devices=${TENSOR_PARALLEL_SIZE} \\\n", - " trainer.num_nodes=${NODES} \\\n", - " \\\n", - " \"model.data.data_prefix={train:[1.0,$DATA_TRAIN],validation:[$DATA_VAL],test:[$DATA_TEST]}\" \\\n", - " \\\n", - " model.restore_from_path=${STUDENT} \\\n", - " model.kd_teacher_restore_from_path=${TEACHER} \\\n", - " model.nemo_path=${FINAL_MODEL_PATH} \\\n", - " \\\n", - " model.tensor_model_parallel_size=${TENSOR_PARALLEL_SIZE} \\\n", - " model.sequence_parallel=True \\\n", - " model.micro_batch_size=${MICRO_BATCH_SIZE} \\\n", - " model.global_batch_size=${GLOBAL_BATCH_SIZE} \\\n", - " \\\n", - " model.optim.name=distributed_fused_adam \\\n", - " model.optim.lr=${LR} \\\n", - " model.optim.sched.min_lr=${MIN_LR} \\\n", - " model.optim.sched.warmup_steps=${WARMUP_STEPS}\n" - ] - }, - { - "cell_type": "markdown", - "id": "fe7034ba-8c69-4edb-8c0f-84fdca43c152", - "metadata": {}, - "source": [ - "This will create the final distilled model named `distilled_4b_model.nemo` in `./distill_trainings/megatron_llama_distill/checkpoints`.\n", - "> `NOTE:`This script takes at least 35 minutes to run and generate the final distilled model." - ] - }, - { - "cell_type": "markdown", - "id": "c9a66d44-5028-47f9-9df3-9f07692e9461", - "metadata": {}, - "source": [ - "### Step 5: Display the validation loss\n", - "\n", - "Now that the results are in, let's visualize the validation loss of the distilled model using the `tensorboard` library. \n", - "> `NOTE:` This notebook demonstrates the use of the teacher finetuning, pruning and the distillation script. These scripts should ideally be run on a multi-node cluster with a larger `GLOBAL_BATCH_SIZE` and `STEPS` to see improvement in the validation loss." 
- ] - }, - { - "cell_type": "code", - "execution_count": null, - "id": "be4da14c-c03f-4c28-accd-8f676dbef8a9", - "metadata": {}, - "outputs": [], - "source": [ - "%load_ext tensorboard\n", - "%tensorboard --logdir \"distill_trainings/megatron_llama_distill/\" --port=6007" - ] - }, - { - "cell_type": "markdown", - "id": "08c63b80-0f24-4dde-b5d6-11db444726ed", - "metadata": {}, - "source": [ - "Here is an image of the validation loss over 30 steps of running the training step in the distillation script." - ] - }, - { - "cell_type": "code", - "execution_count": 4, - "id": "648424fc-6a51-43ca-8f19-6ad05f949054", - "metadata": { - "tags": [] - }, - "outputs": [ - { - "data": { - "text/html": [ - "" - ], - "text/plain": [ - "" - ] - }, - "metadata": {}, - "output_type": "display_data" - } - ], - "source": [ - "from IPython.display import Image, display\n", - "display(Image(url=\"https://github.com/NVIDIA/NeMo/releases/download/r2.0.0rc1/val_loss_distillation.png\", width=400))" - ] - } - ], - "metadata": { - "kernelspec": { - "display_name": "Python 3 (ipykernel)", - "language": "python", - "name": "python3" - }, - "language_info": { - "codemirror_mode": { - "name": "ipython", - "version": 3 - }, - "file_extension": ".py", - "mimetype": "text/x-python", - "name": "python", - "nbconvert_exporter": "python", - "pygments_lexer": "ipython3", - "version": "3.10.12" - } - }, - "nbformat": 4, - "nbformat_minor": 5 -}