diff --git a/colossalai/shardformer/layer/__init__.py b/colossalai/shardformer/layer/__init__.py
index f17fad1b6606..331e4972966c 100644
--- a/colossalai/shardformer/layer/__init__.py
+++ b/colossalai/shardformer/layer/__init__.py
@@ -3,7 +3,7 @@
 from .dropout import DropoutForParallelInput, DropoutForReplicatedInput
 from .embedding import Embedding1D, PaddingEmbedding, VocabParallelEmbedding1D
 from .linear import Linear1D_Col, Linear1D_Row, PaddingLMHead, VocabParallelLMHead1D
-from .loss import cross_entropy_1d
+from .loss import cross_entropy_1d, dist_cross_entropy
 from .normalization import FusedLayerNorm, FusedRMSNorm, LayerNorm, RMSNorm
 from .parallel_module import ParallelModule
 from .qkv_fused_linear import FusedLinear1D_Col, GPT2FusedLinearConv1D_Col, GPT2FusedLinearConv1D_Row
@@ -18,6 +18,7 @@
     "DropoutForParallelInput",
     "DropoutForReplicatedInput",
     "cross_entropy_1d",
+    "dist_cross_entropy",
     "BaseLayerNorm",
     "LayerNorm",
     "RMSNorm",
diff --git a/colossalai/shardformer/layer/loss.py b/colossalai/shardformer/layer/loss.py
index a6d19edf5b53..946f6008eba4 100644
--- a/colossalai/shardformer/layer/loss.py
+++ b/colossalai/shardformer/layer/loss.py
@@ -2,8 +2,11 @@
 import torch.distributed as dist
 from torch.autograd import Function
 from torch.distributed import ProcessGroup
+from torch.nn import CrossEntropyLoss
 
-__all__ = ["DistCrossEntropy", "cross_entropy_1d"]
+from colossalai.shardformer.shard import ShardConfig
+
+__all__ = ["DistCrossEntropy", "cross_entropy_1d", "dist_cross_entropy"]
 
 
 class DistCrossEntropy(Function):
@@ -132,3 +135,41 @@ def cross_entropy_1d(
     dtype: torch.dtype = None,
 ) -> torch.Tensor:
     return DistCrossEntropy.apply(vocab_logits, labels, ignore_index, process_group, vocab_size, dtype)
+
+
+def dist_cross_entropy(
+    labels: torch.Tensor,
+    logits: torch.Tensor,
+    shard_config: ShardConfig,
+    out_features: int,
+    vocab_size: int,
+    dtype: torch.dtype,
+) -> torch.Tensor:
+    """
+    Helper to compute cross entropy loss for most shardformer models,
+    compatible with PP, TP and SP.
+    """
+    if labels is not None:
+        # Shift so that tokens < n predict n
+        shift_logits = logits[..., :-1, :].contiguous()
+        shift_labels = labels[..., 1:].contiguous()
+        # Flatten the tokens
+        loss_fct = CrossEntropyLoss()
+        shift_labels = shift_labels.view(-1)
+        shift_labels = shift_labels.to(shift_logits.device)
+        if shard_config.enable_tensor_parallelism and shard_config.parallel_output:
+            # Cross entropy with all-reduce for TP
+            new_vocab_size = logits.shape[-1]
+            shift_logits = shift_logits.view(-1, new_vocab_size)
+            loss = cross_entropy_1d(
+                shift_logits,
+                shift_labels,
+                process_group=shard_config.tensor_parallel_process_group,
+                vocab_size=out_features,
+                dtype=dtype,
+            )
+        else:
+            shift_logits = shift_logits.view(-1, vocab_size)
+            loss = loss_fct(shift_logits, shift_labels)
+
+        return loss
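For orientation, the short sketch below (not part of the diff itself) shows the call pattern the modeling forwards in this PR all adopt; causal_lm_loss and its parameters are illustrative stand-ins for each model's own lm_head, config and backbone dtype:

import torch
from colossalai.shardformer.layer import dist_cross_entropy
from colossalai.shardformer.shard import ShardConfig


def causal_lm_loss(labels, logits, shard_config: ShardConfig, lm_head, config, dtype: torch.dtype):
    # out_features: full (possibly padded) vocab size of the LM head;
    # vocab_size: the size reported by the HF config; dtype: the backbone model's dtype.
    # Returns None when labels is None, matching the (loss,) + output guards in the forwards below.
    return dist_cross_entropy(labels, logits, shard_config, lm_head.out_features, config.vocab_size, dtype)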
diff --git a/colossalai/shardformer/modeling/bloom.py b/colossalai/shardformer/modeling/bloom.py
index 1541436264e9..5e991b682767 100644
--- a/colossalai/shardformer/modeling/bloom.py
+++ b/colossalai/shardformer/modeling/bloom.py
@@ -28,7 +28,7 @@
 from colossalai.shardformer.layer._operation import gather_forward_split_backward, split_forward_gather_backward
 from colossalai.shardformer.shard import ShardConfig
 
-from ..layer import cross_entropy_1d
+from ..layer import dist_cross_entropy
 
 logger = logging.get_logger(__name__)
 
@@ -359,30 +359,14 @@ def bloom_for_causal_lm_forward(
         hidden_states = transformer_outputs[0]
         lm_logits = self.lm_head(hidden_states).contiguous()
-        loss = None
-        if labels is not None:
-            # move labels to correct device to enable model parallelism
-            labels = labels.to(lm_logits.device)
-            # Shift so that tokens < n predict n
-            shift_logits = lm_logits[..., :-1, :].contiguous()
-            shift_labels = labels[..., 1:].contiguous()
-            batch_size, seq_length, vocab_size = shift_logits.shape
-            # Flatten the tokens
-            if shard_config.enable_tensor_parallelism and shard_config.parallel_output:
-                new_vocab_size = lm_logits.shape[-1]
-                shift_logits = shift_logits.view(-1, new_vocab_size)
-                shift_labels = shift_labels.view(-1)
-                loss = cross_entropy_1d(
-                    shift_logits,
-                    shift_labels,
-                    process_group=shard_config.tensor_parallel_process_group,
-                    vocab_size=self.lm_head.out_features,
-                    dtype=self.transformer.dtype,
-                )
-            else:
-                loss_fct = CrossEntropyLoss()
-                shift_logits = shift_logits.view(-1, self.config.vocab_size)
-                loss = loss_fct(shift_logits, shift_labels.view(-1))
+        loss = dist_cross_entropy(
+            labels,
+            lm_logits,
+            shard_config,
+            self.lm_head.out_features,
+            self.config.vocab_size,
+            self.transformer.dtype,
+        )
 
         if not return_dict:
             output = (lm_logits,) + transformer_outputs[1:]
@@ -1039,25 +1023,10 @@ def forward(
             past_key_values = None
 
         hidden_states = transformer_outputs[0]
         lm_logits = self.lm_head(hidden_states)
+        loss = dist_cross_entropy(
+            labels, lm_logits, shard_config, self.lm_head.out_features, self.config.vocab_size, self.transformer.dtype
+        )
 
-        loss = None
-        if labels is not None:
-            # move labels to correct device to enable model parallelism
-            labels = labels.to(lm_logits.device)
-            # Shift so that tokens < n predict n
-            shift_logits = lm_logits[..., :-1, :].contiguous()
-            shift_labels = labels[..., 1:].contiguous()
-            # Flatten the tokens
-            new_vocab_size = lm_logits.shape[-1]
-            shift_logits = shift_logits.view(-1, new_vocab_size)
-            shift_labels = shift_labels.view(-1)
-            loss = cross_entropy_1d(
-                shift_logits,
-                shift_labels,
-                process_group=shard_config.tensor_parallel_process_group,
-                vocab_size=self.lm_head.out_features,
-                dtype=self.transformer.dtype,
-            )
         if not return_dict:
             output = (lm_logits,) + transformer_outputs[1:]
             return ((loss,) + output) if loss is not None else output
diff --git a/colossalai/shardformer/modeling/command.py b/colossalai/shardformer/modeling/command.py
index 07a7f6cbf8d3..72f705bc0a75 100644
--- a/colossalai/shardformer/modeling/command.py
+++ b/colossalai/shardformer/modeling/command.py
@@ -5,7 +5,6 @@
 import torch
 import torch.utils.checkpoint
 from torch import nn
-from torch.nn import CrossEntropyLoss
 from transformers.cache_utils import Cache, DynamicCache
 from transformers.modeling_outputs import BaseModelOutputWithPast, CausalLMOutputWithPast
 from transformers.models.cohere.modeling_cohere import (
@@ -25,7 +24,7 @@
 )
 from colossalai.shardformer.shard import ShardConfig
 
-from ..layer import ColoAttention, cross_entropy_1d
+from ..layer import ColoAttention, dist_cross_entropy
 
 
 class CommandPipelineForwards:
@@ -300,29 +299,9 @@ def command_for_causal_lm_forward(
             logits = self.lm_head(hidden_states)
             logits = logits * self.logit_scale
             logits = logits.float()
-            loss = None
-            if labels is not None:
-                # Shift so that tokens < n predict n
-                shift_logits = logits[..., :-1, :].contiguous()
-                shift_labels = labels[..., 1:].contiguous()
-                # Flatten the tokens
-                loss_fct = CrossEntropyLoss()
-                shift_labels = shift_labels.view(-1)
-                # Enable model parallelism
-                shift_labels = shift_labels.to(shift_logits.device)
-                if shard_config.enable_tensor_parallelism and shard_config.parallel_output:
-                    new_vocab_size = logits.shape[-1]
-                    shift_logits = shift_logits.view(-1, new_vocab_size)
-                    loss = cross_entropy_1d(
-                        shift_logits,
-                        shift_labels,
-                        process_group=shard_config.tensor_parallel_process_group,
-                        vocab_size=self.lm_head.out_features,
-                        dtype=self.model.dtype,
-                    )
-                else:
-                    shift_logits = shift_logits.view(-1, self.config.vocab_size)
-                    loss = loss_fct(shift_logits, shift_labels)
+            loss = dist_cross_entropy(
+                labels, logits, shard_config, self.lm_head.out_features, self.config.vocab_size, self.model.dtype
+            )
 
             if not return_dict:
                 output = (logits,) + outputs[1:]
@@ -658,24 +637,14 @@ def forward(
         logits = self.lm_head(hidden_states)
         logits = logits * self.logit_scale
         logits = logits.float()
-
-        loss = None
-        if labels is not None:
-            # Shift so that tokens < n predict n
-            shift_logits = logits[..., :-1, :].contiguous()
-            shift_labels = labels[..., 1:].contiguous()
-            shift_labels = shift_labels.view(-1)
-            # Enable model parallelism
-            shift_labels = shift_labels.to(shift_logits.device)
-            new_vocab_size = logits.shape[-1]
-            shift_logits = shift_logits.view(-1, new_vocab_size)
-            loss = cross_entropy_1d(
-                shift_logits,
-                shift_labels,
-                process_group=shard_config.tensor_parallel_process_group,
-                vocab_size=self.lm_head.out_features,
-                dtype=self.model.dtype,
-            )
+        loss = dist_cross_entropy(
+            labels,
+            logits,
+            shard_config,
+            self.lm_head.out_features,
+            self.config.vocab_size,
+            self.model.dtype,
+        )
 
         if not return_dict:
             output = (logits,) + outputs[1:]
diff --git a/colossalai/shardformer/modeling/gpt2.py b/colossalai/shardformer/modeling/gpt2.py
index aa75bab115a7..6ecda91c4d35 100644
--- a/colossalai/shardformer/modeling/gpt2.py
+++ b/colossalai/shardformer/modeling/gpt2.py
@@ -25,7 +25,7 @@
 from colossalai.shardformer.layer._operation import gather_forward_split_backward, split_forward_gather_backward
 from colossalai.shardformer.shard import ShardConfig
 
-from ..layer import cross_entropy_1d
+from ..layer import dist_cross_entropy
 
 logger = logging.get_logger(__name__)
 
@@ -372,27 +372,9 @@ def gpt2_lmhead_model_forward(
         hidden_states = outputs[0]
         lm_logits = self.lm_head(hidden_states)
-        loss = None
-        if labels is not None:
-            # move labels to correct device to enable model parallelism
-            labels = labels.to(lm_logits.device)
-            # Shift so that tokens < n predict n
-            shift_logits = lm_logits[..., :-1, :].contiguous()
-            shift_labels = labels[..., 1:].contiguous()
-            # Flatten the tokens
-            loss_fct = CrossEntropyLoss()
-            shift_logits = shift_logits.view(-1, shift_logits.size(-1))
-            shift_labels = shift_labels.view(-1)
-            if shard_config.enable_tensor_parallelism and shard_config.parallel_output:
-                loss = cross_entropy_1d(
-                    shift_logits,
-                    shift_labels,
-                    process_group=shard_config.tensor_parallel_process_group,
-                    vocab_size=self.lm_head.out_features,
-                    dtype=self.transformer.dtype,
-                )
-            else:
-                loss = loss_fct(shift_logits, shift_labels)
+        loss = dist_cross_entropy(
+            labels, lm_logits, shard_config, self.lm_head.out_features, self.config.vocab_size, self.transformer.dtype
+        )
 
         if not return_dict:
             output = (lm_logits,) + outputs[1:]
@@ -1282,24 +1264,9 @@ def forward(
         hidden_states = transformer_outputs[0]
         lm_logits = self.lm_head(hidden_states)
-
-        loss = None
-        if labels is not None:
-            # move labels to correct device to enable model parallelism
-            labels = labels.to(lm_logits.device)
-            # Shift so that tokens < n predict n
-            shift_logits = lm_logits[..., :-1, :].contiguous()
-            shift_labels = labels[..., 1:].contiguous()
-            # Flatten the tokens
-            shift_logits = shift_logits.view(-1, shift_logits.size(-1))
-            shift_labels = shift_labels.view(-1)
-            loss = cross_entropy_1d(
-                shift_logits,
-                shift_labels,
-                process_group=shard_config.tensor_parallel_process_group,
-                vocab_size=self.lm_head.out_features,
-                dtype=self.transformer.dtype,
-            )
+        loss = dist_cross_entropy(
+            labels, lm_logits, shard_config, self.lm_head.out_features, self.config.vocab_size, self.transformer.dtype
+        )
 
         if not return_dict:
             output = (lm_logits,) + transformer_outputs[1:]
diff --git a/colossalai/shardformer/modeling/llama.py b/colossalai/shardformer/modeling/llama.py
index cf7fd13c898b..54ff8e321e06 100644
--- a/colossalai/shardformer/modeling/llama.py
+++ b/colossalai/shardformer/modeling/llama.py
@@ -31,7 +31,7 @@
 )
 from colossalai.shardformer.shard import ShardConfig
 
-from ..layer import ColoAttention, cross_entropy_1d
+from ..layer import ColoAttention, dist_cross_entropy
 
 
 class LlamaPipelineForwards:
@@ -321,29 +321,9 @@ def llama_for_causal_lm_forward(
         if stage_manager.is_last_stage():
             hidden_states = outputs[0]
             logits = self.lm_head(hidden_states)
-            loss = None
-            if labels is not None:
-                # Shift so that tokens < n predict n
-                shift_logits = logits[..., :-1, :].contiguous()
-                shift_labels = labels[..., 1:].contiguous()
-                # Flatten the tokens
-                loss_fct = CrossEntropyLoss()
-                shift_labels = shift_labels.view(-1)
-                # Enable model parallelism
-                shift_labels = shift_labels.to(shift_logits.device)
-                if shard_config.enable_tensor_parallelism and shard_config.parallel_output:
-                    new_vocab_size = logits.shape[-1]
-                    shift_logits = shift_logits.view(-1, new_vocab_size)
-                    loss = cross_entropy_1d(
-                        shift_logits,
-                        shift_labels,
-                        process_group=shard_config.tensor_parallel_process_group,
-                        vocab_size=self.lm_head.out_features,
-                        dtype=self.model.dtype,
-                    )
-                else:
-                    shift_logits = shift_logits.view(-1, self.config.vocab_size)
-                    loss = loss_fct(shift_logits, shift_labels)
+            loss = dist_cross_entropy(
+                labels, logits, shard_config, self.lm_head.out_features, self.config.vocab_size, self.model.dtype
+            )
 
             if not return_dict:
                 output = (logits,) + outputs[1:]
@@ -820,24 +800,9 @@ def forward(
         logits = self.lm_head(hidden_states)
         logits = logits.float()
 
-        loss = None
-        if labels is not None:
-            # Shift so that tokens < n predict n
-            shift_logits = logits[..., :-1, :].contiguous()
-            shift_labels = labels[..., 1:].contiguous()
-            shift_labels = shift_labels.view(-1)
-            # Enable model parallelism
-            shift_labels = shift_labels.to(shift_logits.device)
-            new_vocab_size = logits.shape[-1]
-            shift_logits = shift_logits.view(-1, new_vocab_size)
-            loss = cross_entropy_1d(
-                shift_logits,
-                shift_labels,
-                process_group=shard_config.tensor_parallel_process_group,
-                vocab_size=self.lm_head.out_features,
-                dtype=self.model.dtype,
-            )
-
+        loss = dist_cross_entropy(
+            labels, logits, shard_config, self.lm_head.out_features, self.config.vocab_size, self.model.dtype
+        )
         if not return_dict:
             output = (logits,) + outputs[1:]
             return (loss,) + output if loss is not None else output
diff --git a/colossalai/shardformer/modeling/mistral.py b/colossalai/shardformer/modeling/mistral.py
index 310c2d8e233a..82e8ef5f9af7 100644
--- a/colossalai/shardformer/modeling/mistral.py
+++ b/colossalai/shardformer/modeling/mistral.py
@@ -19,7 +19,7 @@
 from colossalai.pipeline.stage_manager import PipelineStageManager
 from colossalai.shardformer.shard import ShardConfig
 
-from ..layer import ColoAttention, cross_entropy_1d
+from ..layer import ColoAttention, dist_cross_entropy
 
 logger = logging.get_logger(__name__)
 
@@ -275,29 +275,9 @@ def mistral_for_causal_lm_forward(
         logits = self.lm_head(hidden_states)
         logits = logits.float()
-        loss = None
-        if labels is not None:
-            # Shift so that tokens < n predict n
-            shift_logits = logits[..., :-1, :].contiguous()
-            shift_labels = labels[..., 1:].contiguous()
-            # Flatten the tokens
-            loss_fct = CrossEntropyLoss()
-            shift_labels = shift_labels.view(-1)
-            # Enable model parallelism
-            shift_labels = shift_labels.to(shift_logits.device)
-            if shard_config.enable_tensor_parallelism and shard_config.parallel_output:
-                new_vocab_size = logits.shape[-1]
-                shift_logits = shift_logits.view(-1, new_vocab_size)
-                loss = cross_entropy_1d(
-                    shift_logits,
-                    shift_labels,
-                    process_group=shard_config.tensor_parallel_process_group,
-                    vocab_size=self.lm_head.out_features,
-                    dtype=self.model.dtype,
-                )
-            else:
-                shift_logits = shift_logits.view(-1, self.config.vocab_size)
-                loss = loss_fct(shift_logits, shift_labels)
+        loss = dist_cross_entropy(
+            labels, logits, shard_config, self.lm_head.out_features, self.config.vocab_size, self.model.dtype
+        )
 
         if not return_dict:
             output = (logits,) + outputs[1:]
@@ -708,23 +688,9 @@ def forward(
         logits = self.lm_head(hidden_states)
         logits = logits.float()
 
-        loss = None
-        if labels is not None:
-            # Shift so that tokens < n predict n
-            shift_logits = logits[..., :-1, :].contiguous()
-            shift_labels = labels[..., 1:].contiguous()
-            shift_labels = shift_labels.view(-1)
-            # Enable model parallelism
-            shift_labels = shift_labels.to(shift_logits.device)
-            new_vocab_size = logits.shape[-1]
-            shift_logits = shift_logits.view(-1, new_vocab_size)
-            loss = cross_entropy_1d(
-                shift_logits,
-                shift_labels,
-                process_group=shard_config.tensor_parallel_process_group,
-                vocab_size=self.lm_head.out_features,
-                dtype=self.model.dtype,
-            )
+        loss = dist_cross_entropy(
+            labels, logits, shard_config, self.lm_head.out_features, self.config.vocab_size, self.model.dtype
+        )
 
         if not return_dict:
             output = (logits,) + outputs[1:]
diff --git a/colossalai/shardformer/modeling/opt.py b/colossalai/shardformer/modeling/opt.py
index f10860fef558..31bf481de8ec 100644
--- a/colossalai/shardformer/modeling/opt.py
+++ b/colossalai/shardformer/modeling/opt.py
@@ -22,7 +22,7 @@
 from colossalai.shardformer.layer import ColoAttention
 from colossalai.shardformer.shard import ShardConfig
 
-from ..layer import cross_entropy_1d
+from ..layer import dist_cross_entropy
 
 logger = logging.get_logger(__name__)
 
@@ -330,30 +330,14 @@ def opt_for_causal_lm_forward(
         )
         if stage_manager.is_last_stage():
             logits = self.lm_head(outputs[0]).contiguous()
-            loss = None
-            if labels is not None:
-                # move labels to correct device to enable model parallelism
-                labels = labels.to(logits.device)
-                # Shift so that tokens < n predict n
-                shift_logits = logits[..., :-1, :].contiguous()
-                shift_labels = labels[..., 1:].contiguous()
-                # Flatten the tokens
-
-                if shard_config.enable_tensor_parallelism and shard_config.parallel_output:
-                    new_vocab_size = logits.shape[-1]
-                    shift_logits = shift_logits.view(-1, new_vocab_size)
-                    shift_labels = shift_labels.view(-1)
-                    loss = cross_entropy_1d(
-                        shift_logits,
-                        shift_labels,
-                        process_group=shard_config.tensor_parallel_process_group,
-                        vocab_size=self.lm_head.out_features,
-                        dtype=self.model.decoder.dtype,
-                    )
-                else:
-                    loss_fct = CrossEntropyLoss()
-                    shift_logits = shift_logits.view(-1, self.config.vocab_size)
-                    loss = loss_fct(shift_logits.view(-1, self.config.vocab_size), shift_labels.view(-1))
+            loss = dist_cross_entropy(
+                labels,
+                logits,
+                shard_config,
+                self.lm_head.out_features,
+                self.config.vocab_size,
+                self.model.decoder.dtype,
+            )
 
             if not return_dict:
                 output = (logits,) + outputs[1:]
@@ -971,26 +955,9 @@ def forward(
         )
 
         logits = self.lm_head(outputs[0]).contiguous()
-
-        loss = None
-        if labels is not None:
-            # move labels to correct device to enable model parallelism
-            labels = labels.to(logits.device)
-            # Shift so that tokens < n predict n
-            shift_logits = logits[..., :-1, :].contiguous()
-            shift_labels = labels[..., 1:].contiguous()
-            shift_labels = shift_labels.view(-1)
-            # Enable model parallelism
-            shift_labels = shift_labels.to(shift_logits.device)
-            new_vocab_size = logits.shape[-1]
-            shift_logits = shift_logits.view(-1, new_vocab_size)
-            loss = cross_entropy_1d(
-                shift_logits,
-                shift_labels,
-                process_group=shard_config.tensor_parallel_process_group,
-                vocab_size=self.lm_head.out_features,
-                dtype=self.model.decoder.dtype,
-            )
+        loss = dist_cross_entropy(
+            labels, logits, shard_config, self.lm_head.out_features, self.config.vocab_size, self.model.decoder.dtype
+        )
 
         if not return_dict:
            output = (logits,) + outputs[1:]
diff --git a/colossalai/shardformer/modeling/qwen2.py b/colossalai/shardformer/modeling/qwen2.py
index e0aa5fba4a01..cc74e76bec41 100644
--- a/colossalai/shardformer/modeling/qwen2.py
+++ b/colossalai/shardformer/modeling/qwen2.py
@@ -32,7 +32,7 @@
 from colossalai.pipeline.stage_manager import PipelineStageManager
 from colossalai.shardformer.shard import ShardConfig
 
-from ..layer import ColoAttention, cross_entropy_1d
+from ..layer import ColoAttention, dist_cross_entropy
 
 
 class Qwen2PipelineForwards:
@@ -304,25 +304,7 @@ def qwen2_for_causal_lm_forward(
         if stage_manager.is_last_stage():
             hidden_states = outputs[0]
             logits = self.lm_head(hidden_states)
-            loss = None
-            if labels is not None:
-                # Shift so that tokens < n predict n
-                shift_logits = logits[..., :-1, :].contiguous()
-                shift_labels = labels[..., 1:].contiguous()
-                # Flatten the tokens
-                loss_fct = CrossEntropyLoss()
-                shift_labels = shift_labels.view(-1)
-                # Enable model parallelism
-                shift_labels = shift_labels.to(shift_logits.device)
-                if shard_config.enable_tensor_parallelism:
-                    new_vocab_size = logits.shape[-1]
-                    shift_logits = shift_logits.view(-1, new_vocab_size)
-                    loss = cross_entropy_1d(
-                        shift_logits, shift_labels, process_group=shard_config.tensor_parallel_process_group
-                    )
-                else:
-                    shift_logits = shift_logits.view(-1, self.config.vocab_size)
-                    loss = loss_fct(shift_logits, shift_labels)
+            loss = dist_cross_entropy(labels, logits, shard_config, self.lm_head.out_features, self.config.vocab_size, self.model.dtype)
 
             if not return_dict:
                 output = (logits,) + outputs[1:]
@@ -724,26 +706,7 @@ def forward(
         hidden_states = outputs[0]
         logits = self.lm_head(hidden_states)
         logits = logits.float()
-
-        loss = None
-        if labels is not None:
-            # Shift so that tokens < n predict n
-            shift_logits = logits[..., :-1, :].contiguous()
-            shift_labels = labels[..., 1:].contiguous()
-            # Flatten the tokens
-            loss_fct = CrossEntropyLoss()
-            shift_labels = shift_labels.view(-1)
-            # Enable model parallelism
-            shift_labels = shift_labels.to(shift_logits.device)
-            if shard_config.enable_tensor_parallelism:
-                new_vocab_size = logits.shape[-1]
-                shift_logits = shift_logits.view(-1, new_vocab_size)
-                loss = cross_entropy_1d(
-                    shift_logits, shift_labels, process_group=shard_config.tensor_parallel_process_group
-                )
-            else:
-                shift_logits = shift_logits.view(-1, self.config.vocab_size)
-                loss = loss_fct(shift_logits, shift_labels)
+        loss = dist_cross_entropy(labels, logits, shard_config, self.lm_head.out_features, self.config.vocab_size, self.model.dtype)
 
         if not return_dict:
             output = (logits,) + outputs[1:]
diff --git a/examples/language/opt/opt_benchmark.py b/examples/language/opt/opt_benchmark.py
index c2883d96c16e..53108cd74bf5 100755
--- a/examples/language/opt/opt_benchmark.py
+++ b/examples/language/opt/opt_benchmark.py
@@ -127,4 +127,5 @@ def main():
 
 
 if __name__ == "__main__":
+    print("--------------------------------------")
     main()
diff --git a/examples/language/opt/opt_train_demo.py b/examples/language/opt/opt_train_demo.py
index b5b50305cc34..c1ea8dedbb81 100644
--- a/examples/language/opt/opt_train_demo.py
+++ b/examples/language/opt/opt_train_demo.py
@@ -85,6 +85,7 @@ def main():
 
     # Enable gradient checkpointing
     model.gradient_checkpointing_enable()
+    print("model._gradient_checkpointing_func:", model._gradient_checkpointing_func)
 
     # Set plugin
     booster_kwargs = {}
diff --git a/tests/test_shardformer/test_model/test_shard_llama.py b/tests/test_shardformer/test_model/test_shard_llama.py
index b32eb0d31ea1..88e54176b9fd 100644
--- a/tests/test_shardformer/test_model/test_shard_llama.py
+++ b/tests/test_shardformer/test_model/test_shard_llama.py
@@ -63,9 +63,7 @@ def check_forward_backward(model_fn, data_gen_fn, output_transform_fn, loss_fn,
             and booster.plugin.shard_config.sequence_parallelism_mode == "all_to_all"
         ):
             master2working = sharded_optimizer.get_master_to_working_map()
-            for (name, p1), p2 in zip(
-                llama_model.named_parameters(), sharded_optimizer._master_param_groups_of_current_rank[0]
-            ):
+            for p1, p2 in zip(llama_model.parameters(), sharded_optimizer._master_param_groups_of_current_rank[0]):
                 working_p = master2working[id(p2)]
                 grads = sharded_optimizer.get_partitioned_gradients_by_param_id(0, id(working_p))
                 grad_index = (
@@ -75,11 +73,7 @@ def check_forward_backward(model_fn, data_gen_fn, output_transform_fn, loss_fn,
                 )
                 grad = grads[grad_index]
                 sharded_grad = p1.grad.view(-1).chunk(dist.get_world_size())[dist.get_rank()]
-                try:
-                    assert_close(sharded_grad, grad[: sharded_grad.shape[0]], atol=5e-3, rtol=5e-3, check_dtype=False)
-                except Exception as e:
-                    print(f"Failed param name: {name}")
-                    raise e
+                assert_close(sharded_grad, grad[: sharded_grad.shape[0]], atol=5e-3, rtol=5e-3, check_dtype=False)
 
     # Save gradient tensors for comparison between the original model and the sharded model before optimizer step.
     grads_to_check = {}
@@ -253,7 +247,6 @@ def run_llama_test(test_config):
     except Exception as e:
         print(f"Failed config: {test_config}")
         raise e
-    clear_layout_converter()
     Randomizer.reset_index()
     torch.cuda.empty_cache()
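For reference, when tensor parallelism (or parallel_output) is disabled, dist_cross_entropy reduces to the usual shifted cross entropy. A minimal single-process sketch of that fallback math in plain PyTorch (not part of the diff; shapes are arbitrary):

import torch
from torch.nn import CrossEntropyLoss


def shifted_cross_entropy(labels: torch.Tensor, logits: torch.Tensor) -> torch.Tensor:
    # Shift so that tokens < n predict n, then flatten, exactly as the fallback branch does.
    shift_logits = logits[..., :-1, :].contiguous().view(-1, logits.size(-1))
    shift_labels = labels[..., 1:].contiguous().view(-1).to(shift_logits.device)
    return CrossEntropyLoss()(shift_logits, shift_labels)


logits = torch.randn(2, 8, 32)         # (batch, seq_len, vocab_size)
labels = torch.randint(0, 32, (2, 8))  # (batch, seq_len)
print(shifted_cross_entropy(labels, logits))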