use one cross entropy func for all shardformer models
Edenzzzz committed Jul 2, 2024
1 parent 333afb6 commit 483a9fc
Showing 12 changed files with 108 additions and 305 deletions.
3 changes: 2 additions & 1 deletion colossalai/shardformer/layer/__init__.py
@@ -3,7 +3,7 @@
 from .dropout import DropoutForParallelInput, DropoutForReplicatedInput
 from .embedding import Embedding1D, PaddingEmbedding, VocabParallelEmbedding1D
 from .linear import Linear1D_Col, Linear1D_Row, PaddingLMHead, VocabParallelLMHead1D
-from .loss import cross_entropy_1d
+from .loss import cross_entropy_1d, dist_cross_entropy
 from .normalization import FusedLayerNorm, FusedRMSNorm, LayerNorm, RMSNorm
 from .parallel_module import ParallelModule
 from .qkv_fused_linear import FusedLinear1D_Col, GPT2FusedLinearConv1D_Col, GPT2FusedLinearConv1D_Row
@@ -18,6 +18,7 @@
     "DropoutForParallelInput",
     "DropoutForReplicatedInput",
     "cross_entropy_1d",
+    "dist_cross_entropy",
     "BaseLayerNorm",
     "LayerNorm",
     "RMSNorm",
43 changes: 42 additions & 1 deletion colossalai/shardformer/layer/loss.py
@@ -2,8 +2,11 @@
 import torch.distributed as dist
 from torch.autograd import Function
 from torch.distributed import ProcessGroup
+from torch.nn import CrossEntropyLoss

-__all__ = ["DistCrossEntropy", "cross_entropy_1d"]
+from colossalai.shardformer.shard import ShardConfig
+
+__all__ = ["DistCrossEntropy", "cross_entropy_1d", "dist_cross_entropy"]


 class DistCrossEntropy(Function):
@@ -132,3 +135,41 @@ def cross_entropy_1d(
     dtype: torch.dtype = None,
 ) -> torch.Tensor:
     return DistCrossEntropy.apply(vocab_logits, labels, ignore_index, process_group, vocab_size, dtype)
+
+
+def dist_cross_entropy(
+    labels: torch.Tensor,
+    logits: torch.Tensor,
+    shard_config: ShardConfig,
+    out_features: int,
+    vocab_size: int,
+    dtype: torch.dtype,
+) -> torch.Tensor:
+    """
+    Helper to compute cross entropy loss for most shardformer models,
+    compatible with PP, TP and SP.
+    """
+    if labels is not None:
+        # Shift so that tokens < n predict n
+        shift_logits = logits[..., :-1, :].contiguous()
+        shift_labels = labels[..., 1:].contiguous()
+        # Flatten the tokens
+        loss_fct = CrossEntropyLoss()
+        shift_labels = shift_labels.view(-1)
+        shift_labels = shift_labels.to(shift_logits.device)
+        if shard_config.enable_tensor_parallelism and shard_config.parallel_output:
+            # Cross entropy with all-reduce for TP
+            new_vocab_size = logits.shape[-1]
+            shift_logits = shift_logits.view(-1, new_vocab_size)
+            loss = cross_entropy_1d(
+                shift_logits,
+                shift_labels,
+                process_group=shard_config.tensor_parallel_process_group,
+                vocab_size=out_features,
+                dtype=dtype,
+            )
+        else:
+            shift_logits = shift_logits.view(-1, vocab_size)
+            loss = loss_fct(shift_logits, shift_labels)

+        return loss
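
For intuition: when shard_config.parallel_output is set under tensor parallelism, each rank holds only a contiguous shard of the vocabulary dimension, and cross_entropy_1d (backed by DistCrossEntropy) assembles the loss from per-shard pieces instead of gathering the full logits. Below is a minimal single-process sketch of that computation (not part of this commit; the sizes and variable names are illustrative, and plain tensor slicing stands in for the all-reduce calls):

# Illustrative, single-process sketch of the vocab-parallel cross-entropy path.
import torch
import torch.nn.functional as F

torch.manual_seed(0)
tp_size, n_tokens, vocab = 4, 8, 32            # illustrative sizes
logits = torch.randn(n_tokens, vocab)          # full (already shifted) logits
labels = torch.randint(0, vocab, (n_tokens,))  # shifted labels

shards = logits.chunk(tp_size, dim=-1)         # each "rank" sees one vocab shard
shard_size = vocab // tp_size

# 1) global max over the vocab dim (an all-reduce MAX in the distributed version)
global_max = torch.stack([s.max(dim=-1).values for s in shards]).max(dim=0).values

# 2) softmax denominator (an all-reduce SUM in the distributed version)
sum_exp = sum((s - global_max[:, None]).exp().sum(dim=-1) for s in shards)

# 3) only the rank owning a label contributes its target logit (all-reduce SUM)
target_logit = torch.zeros(n_tokens)
for rank, s in enumerate(shards):
    lo, hi = rank * shard_size, (rank + 1) * shard_size
    mask = (labels >= lo) & (labels < hi)
    target_logit[mask] = s[mask, labels[mask] - lo]

# loss = -log softmax(target) = log(sum_exp) - (target_logit - global_max)
loss = (sum_exp.log() - (target_logit - global_max)).mean()
assert torch.allclose(loss, F.cross_entropy(logits, labels))

In the distributed implementation, steps 1-3 roughly correspond to dist.all_reduce calls over shard_config.tensor_parallel_process_group, so no rank ever materializes the full vocabulary dimension.
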
55 changes: 12 additions & 43 deletions colossalai/shardformer/modeling/bloom.py
@@ -28,7 +28,7 @@
 from colossalai.shardformer.layer._operation import gather_forward_split_backward, split_forward_gather_backward
 from colossalai.shardformer.shard import ShardConfig

-from ..layer import cross_entropy_1d
+from ..layer import dist_cross_entropy

 logger = logging.get_logger(__name__)

@@ -359,30 +359,14 @@ def bloom_for_causal_lm_forward(
         hidden_states = transformer_outputs[0]
         lm_logits = self.lm_head(hidden_states).contiguous()

-        loss = None
-        if labels is not None:
-            # move labels to correct device to enable model parallelism
-            labels = labels.to(lm_logits.device)
-            # Shift so that tokens < n predict n
-            shift_logits = lm_logits[..., :-1, :].contiguous()
-            shift_labels = labels[..., 1:].contiguous()
-            batch_size, seq_length, vocab_size = shift_logits.shape
-            # Flatten the tokens
-            if shard_config.enable_tensor_parallelism and shard_config.parallel_output:
-                new_vocab_size = lm_logits.shape[-1]
-                shift_logits = shift_logits.view(-1, new_vocab_size)
-                shift_labels = shift_labels.view(-1)
-                loss = cross_entropy_1d(
-                    shift_logits,
-                    shift_labels,
-                    process_group=shard_config.tensor_parallel_process_group,
-                    vocab_size=self.lm_head.out_features,
-                    dtype=self.transformer.dtype,
-                )
-            else:
-                loss_fct = CrossEntropyLoss()
-                shift_logits = shift_logits.view(-1, self.config.vocab_size)
-                loss = loss_fct(shift_logits, shift_labels.view(-1))
+        loss = dist_cross_entropy(
+            labels,
+            lm_logits,
+            shard_config,
+            self.lm_head.out_features,
+            self.config.vocab_size,
+            self.transformer.dtype,
+        )

         if not return_dict:
             output = (lm_logits,) + transformer_outputs[1:]
@@ -1039,25 +1023,10 @@ def forward(
         past_key_values = None
         hidden_states = transformer_outputs[0]
         lm_logits = self.lm_head(hidden_states)
+        loss = dist_cross_entropy(
+            labels, lm_logits, shard_config, self.lm_head.out_features, self.config.vocab_size, self.transformer.dtype
+        )

-        loss = None
-        if labels is not None:
-            # move labels to correct device to enable model parallelism
-            labels = labels.to(lm_logits.device)
-            # Shift so that tokens < n predict n
-            shift_logits = lm_logits[..., :-1, :].contiguous()
-            shift_labels = labels[..., 1:].contiguous()
-            # Flatten the tokens
-            new_vocab_size = lm_logits.shape[-1]
-            shift_logits = shift_logits.view(-1, new_vocab_size)
-            shift_labels = shift_labels.view(-1)
-            loss = cross_entropy_1d(
-                shift_logits,
-                shift_labels,
-                process_group=shard_config.tensor_parallel_process_group,
-                vocab_size=self.lm_head.out_features,
-                dtype=self.transformer.dtype,
-            )
         if not return_dict:
             output = (lm_logits,) + transformer_outputs[1:]
             return ((loss,) + output) if loss is not None else output
55 changes: 12 additions & 43 deletions colossalai/shardformer/modeling/command.py
@@ -5,7 +5,6 @@
 import torch
 import torch.utils.checkpoint
 from torch import nn
-from torch.nn import CrossEntropyLoss
 from transformers.cache_utils import Cache, DynamicCache
 from transformers.modeling_outputs import BaseModelOutputWithPast, CausalLMOutputWithPast
 from transformers.models.cohere.modeling_cohere import (
@@ -25,7 +24,7 @@
 )
 from colossalai.shardformer.shard import ShardConfig

-from ..layer import ColoAttention, cross_entropy_1d
+from ..layer import ColoAttention, dist_cross_entropy


 class CommandPipelineForwards:
@@ -300,29 +299,9 @@ def command_for_causal_lm_forward(
         logits = self.lm_head(hidden_states)
         logits = logits * self.logit_scale
         logits = logits.float()
-        loss = None
-        if labels is not None:
-            # Shift so that tokens < n predict n
-            shift_logits = logits[..., :-1, :].contiguous()
-            shift_labels = labels[..., 1:].contiguous()
-            # Flatten the tokens
-            loss_fct = CrossEntropyLoss()
-            shift_labels = shift_labels.view(-1)
-            # Enable model parallelism
-            shift_labels = shift_labels.to(shift_logits.device)
-            if shard_config.enable_tensor_parallelism and shard_config.parallel_output:
-                new_vocab_size = logits.shape[-1]
-                shift_logits = shift_logits.view(-1, new_vocab_size)
-                loss = cross_entropy_1d(
-                    shift_logits,
-                    shift_labels,
-                    process_group=shard_config.tensor_parallel_process_group,
-                    vocab_size=self.lm_head.out_features,
-                    dtype=self.model.dtype,
-                )
-            else:
-                shift_logits = shift_logits.view(-1, self.config.vocab_size)
-                loss = loss_fct(shift_logits, shift_labels)
+        loss = dist_cross_entropy(
+            labels, logits, shard_config, self.lm_head.out_features, self.config.vocab_size, self.model.dtype
+        )

         if not return_dict:
             output = (logits,) + outputs[1:]
@@ -658,24 +637,14 @@ def forward(
         logits = self.lm_head(hidden_states)
         logits = logits * self.logit_scale
         logits = logits.float()
-
-        loss = None
-        if labels is not None:
-            # Shift so that tokens < n predict n
-            shift_logits = logits[..., :-1, :].contiguous()
-            shift_labels = labels[..., 1:].contiguous()
-            shift_labels = shift_labels.view(-1)
-            # Enable model parallelism
-            shift_labels = shift_labels.to(shift_logits.device)
-            new_vocab_size = logits.shape[-1]
-            shift_logits = shift_logits.view(-1, new_vocab_size)
-            loss = cross_entropy_1d(
-                shift_logits,
-                shift_labels,
-                process_group=shard_config.tensor_parallel_process_group,
-                vocab_size=self.lm_head.out_features,
-                dtype=self.model.dtype,
-            )
+        loss = dist_cross_entropy(
+            labels,
+            logits,
+            shard_config,
+            self.lm_head.out_features,
+            self.config.vocab_size,
+            self.model.dtype,
+        )

         if not return_dict:
             output = (logits,) + outputs[1:]
47 changes: 7 additions & 40 deletions colossalai/shardformer/modeling/gpt2.py
@@ -25,7 +25,7 @@
 from colossalai.shardformer.layer._operation import gather_forward_split_backward, split_forward_gather_backward
 from colossalai.shardformer.shard import ShardConfig

-from ..layer import cross_entropy_1d
+from ..layer import dist_cross_entropy

 logger = logging.get_logger(__name__)

@@ -372,27 +372,9 @@ def gpt2_lmhead_model_forward(

         hidden_states = outputs[0]
         lm_logits = self.lm_head(hidden_states)
-        loss = None
-        if labels is not None:
-            # move labels to correct device to enable model parallelism
-            labels = labels.to(lm_logits.device)
-            # Shift so that tokens < n predict n
-            shift_logits = lm_logits[..., :-1, :].contiguous()
-            shift_labels = labels[..., 1:].contiguous()
-            # Flatten the tokens
-            loss_fct = CrossEntropyLoss()
-            shift_logits = shift_logits.view(-1, shift_logits.size(-1))
-            shift_labels = shift_labels.view(-1)
-            if shard_config.enable_tensor_parallelism and shard_config.parallel_output:
-                loss = cross_entropy_1d(
-                    shift_logits,
-                    shift_labels,
-                    process_group=shard_config.tensor_parallel_process_group,
-                    vocab_size=self.lm_head.out_features,
-                    dtype=self.transformer.dtype,
-                )
-            else:
-                loss = loss_fct(shift_logits, shift_labels)
+        loss = dist_cross_entropy(
+            labels, lm_logits, shard_config, self.lm_head.out_features, self.config.vocab_size, self.transformer.dtype
+        )

         if not return_dict:
             output = (lm_logits,) + outputs[1:]
@@ -1282,24 +1264,9 @@ def forward(
         hidden_states = transformer_outputs[0]

         lm_logits = self.lm_head(hidden_states)
-
-        loss = None
-        if labels is not None:
-            # move labels to correct device to enable model parallelism
-            labels = labels.to(lm_logits.device)
-            # Shift so that tokens < n predict n
-            shift_logits = lm_logits[..., :-1, :].contiguous()
-            shift_labels = labels[..., 1:].contiguous()
-            # Flatten the tokens
-            shift_logits = shift_logits.view(-1, shift_logits.size(-1))
-            shift_labels = shift_labels.view(-1)
-            loss = cross_entropy_1d(
-                shift_logits,
-                shift_labels,
-                process_group=shard_config.tensor_parallel_process_group,
-                vocab_size=self.lm_head.out_features,
-                dtype=self.transformer.dtype,
-            )
+        loss = dist_cross_entropy(
+            labels, lm_logits, shard_config, self.lm_head.out_features, self.config.vocab_size, self.transformer.dtype
+        )

         if not return_dict:
             output = (lm_logits,) + transformer_outputs[1:]
49 changes: 7 additions & 42 deletions colossalai/shardformer/modeling/llama.py
@@ -31,7 +31,7 @@
 )
 from colossalai.shardformer.shard import ShardConfig

-from ..layer import ColoAttention, cross_entropy_1d
+from ..layer import ColoAttention, dist_cross_entropy


 class LlamaPipelineForwards:
@@ -321,29 +321,9 @@ def llama_for_causal_lm_forward(
         if stage_manager.is_last_stage():
             hidden_states = outputs[0]
             logits = self.lm_head(hidden_states)
-            loss = None
-            if labels is not None:
-                # Shift so that tokens < n predict n
-                shift_logits = logits[..., :-1, :].contiguous()
-                shift_labels = labels[..., 1:].contiguous()
-                # Flatten the tokens
-                loss_fct = CrossEntropyLoss()
-                shift_labels = shift_labels.view(-1)
-                # Enable model parallelism
-                shift_labels = shift_labels.to(shift_logits.device)
-                if shard_config.enable_tensor_parallelism and shard_config.parallel_output:
-                    new_vocab_size = logits.shape[-1]
-                    shift_logits = shift_logits.view(-1, new_vocab_size)
-                    loss = cross_entropy_1d(
-                        shift_logits,
-                        shift_labels,
-                        process_group=shard_config.tensor_parallel_process_group,
-                        vocab_size=self.lm_head.out_features,
-                        dtype=self.model.dtype,
-                    )
-                else:
-                    shift_logits = shift_logits.view(-1, self.config.vocab_size)
-                    loss = loss_fct(shift_logits, shift_labels)
+            loss = dist_cross_entropy(
+                labels, logits, shard_config, self.lm_head.out_features, self.config.vocab_size, self.model.dtype
+            )

             if not return_dict:
                 output = (logits,) + outputs[1:]
@@ -820,24 +800,9 @@ def forward(
         logits = self.lm_head(hidden_states)
         logits = logits.float()

-        loss = None
-        if labels is not None:
-            # Shift so that tokens < n predict n
-            shift_logits = logits[..., :-1, :].contiguous()
-            shift_labels = labels[..., 1:].contiguous()
-            shift_labels = shift_labels.view(-1)
-            # Enable model parallelism
-            shift_labels = shift_labels.to(shift_logits.device)
-            new_vocab_size = logits.shape[-1]
-            shift_logits = shift_logits.view(-1, new_vocab_size)
-            loss = cross_entropy_1d(
-                shift_logits,
-                shift_labels,
-                process_group=shard_config.tensor_parallel_process_group,
-                vocab_size=self.lm_head.out_features,
-                dtype=self.model.dtype,
-            )
-
+        loss = dist_cross_entropy(
+            labels, logits, shard_config, self.lm_head.out_features, self.config.vocab_size, self.model.dtype
+        )
         if not return_dict:
             output = (logits,) + outputs[1:]
             return (loss,) + output if loss is not None else output
(Diffs for the remaining 6 changed files not shown.)
