chore(format): run black on main (#497)
Co-authored-by: github-actions[bot] <github-actions[bot]@users.noreply.github.com>
github-actions[bot] and github-actions[bot] authored Jun 28, 2024
1 parent 3985029 commit ade2f46
Showing 2 changed files with 54 additions and 31 deletions.
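
The diff below is purely mechanical: black breaks any statement that exceeds its default 88-character line length after the opening bracket, and only when the indented remainder is still too long does it put each element on its own line and append a trailing comma. A rough illustration of both cases, with made-up names rather than code from this repository:

# Case 1: the arguments still fit on one indented line, so black only wraps
# the call and adds no trailing comma.
embedding = build_rotary_embedding(hidden_size_per_head, max_position_embeddings_limit_value)
# becomes
embedding = build_rotary_embedding(
    hidden_size_per_head, max_position_embeddings_limit_value
)

# Case 2: even the indented argument line would exceed 88 characters, so each
# argument gets its own line plus a trailing comma.
archive = os.path.join(pretrained_model_name_or_path_value, checkpoint_subfolder_name, sharded_weights_index_filename_variant)
# becomes
archive = os.path.join(
    pretrained_model_name_or_path_value,
    checkpoint_subfolder_name,
    sharded_weights_index_filename_variant,
)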
ChatTTS/model/cuda/te_llama.py (80 changes: 50 additions & 30 deletions)
@@ -20,7 +20,11 @@
     LlamaModel,
     LlamaConfig,
 )
-from transformers.modeling_utils import _add_variant, load_state_dict, _load_state_dict_into_model
+from transformers.modeling_utils import (
+    _add_variant,
+    load_state_dict,
+    _load_state_dict_into_model,
+)
 from transformers.utils import WEIGHTS_INDEX_NAME
 from transformers.utils.hub import get_checkpoint_shard_files
 
@@ -30,12 +34,16 @@ def replace_decoder(te_decoder_cls):
     """
     Replace `LlamaDecoderLayer` with custom `TELlamaDecoderLayer`.
     """
-    original_llama_decoder_cls = transformers.models.llama.modeling_llama.LlamaDecoderLayer
+    original_llama_decoder_cls = (
+        transformers.models.llama.modeling_llama.LlamaDecoderLayer
+    )
     transformers.models.llama.modeling_llama.LlamaDecoderLayer = te_decoder_cls
     try:
         yield
     finally:
-        transformers.models.llama.modeling_llama.LlamaDecoderLayer = original_llama_decoder_cls
+        transformers.models.llama.modeling_llama.LlamaDecoderLayer = (
+            original_llama_decoder_cls
+        )
 
 
 class TELlamaDecoderLayer(te.pytorch.TransformerLayer):
@@ -64,7 +72,9 @@ def __init__(self, config, *args, **kwargs):
             attn_input_format="bshd",
             num_gqa_groups=config.num_key_value_heads,
         )
-        te_rope = RotaryPositionEmbedding(config.hidden_size // config.num_attention_heads)
+        te_rope = RotaryPositionEmbedding(
+            config.hidden_size // config.num_attention_heads
+        )
         self.te_rope_emb = te_rope(max_seq_len=config.max_position_embeddings).cuda()
 
     def forward(self, hidden_states, *args, attention_mask, **kwargs):
@@ -75,7 +85,9 @@ def forward(self, hidden_states, *args, attention_mask, **kwargs):
         """
         return (
             super().forward(
-                hidden_states, attention_mask=attention_mask, rotary_pos_emb=self.te_rope_emb
+                hidden_states,
+                attention_mask=attention_mask,
+                rotary_pos_emb=self.te_rope_emb,
             ),
         )
 
@@ -96,7 +108,9 @@ def __new__(cls, config: LlamaConfig):
         return model
 
     @classmethod
-    def from_pretrained_local(cls, pretrained_model_name_or_path, *args, config, **kwargs):
+    def from_pretrained_local(
+        cls, pretrained_model_name_or_path, *args, config, **kwargs
+    ):
         """
         Custom method adapted from `from_pretrained` method in HuggingFace
         Transformers repo: https://github.com/huggingface/transformers/blob/f497f564bb76697edab09184a252fc1b1a326d1e/src/transformers/modeling_utils.py#L2579
@@ -120,16 +134,22 @@ def from_pretrained_local(cls, pretrained_model_name_or_path, *args, config, **k
             is_sharded = True
         elif os.path.isfile(
             os.path.join(
-                pretrained_model_name_or_path, subfolder, _add_variant(WEIGHTS_INDEX_NAME, variant)
+                pretrained_model_name_or_path,
+                subfolder,
+                _add_variant(WEIGHTS_INDEX_NAME, variant),
             )
         ):
             # Load from a sharded PyTorch checkpoint
             archive_file = os.path.join(
-                pretrained_model_name_or_path, subfolder, _add_variant(WEIGHTS_INDEX_NAME, variant)
+                pretrained_model_name_or_path,
+                subfolder,
+                _add_variant(WEIGHTS_INDEX_NAME, variant),
             )
             is_sharded = True
         else:
-            raise AssertionError("Only sharded PyTorch ckpt format supported at the moment")
+            raise AssertionError(
+                "Only sharded PyTorch ckpt format supported at the moment"
+            )
 
         resolved_archive_file, sharded_metadata = get_checkpoint_shard_files(
             pretrained_model_name_or_path,
@@ -168,34 +188,34 @@ def _replace_params(hf_state_dict, te_state_dict, config):
         # When loading weights into models with less number of layers, skip the
         # copy if the corresponding layer doesn't exist in HF model
         if layer_prefix + "input_layernorm.weight" in hf_state_dict:
-            te_state_dict[layer_prefix + "self_attention.layernorm_qkv.layer_norm_weight"].data[
-                :
-            ] = hf_state_dict[layer_prefix + "input_layernorm.weight"].data[:]
+            te_state_dict[
+                layer_prefix + "self_attention.layernorm_qkv.layer_norm_weight"
+            ].data[:] = hf_state_dict[layer_prefix + "input_layernorm.weight"].data[:]
 
         if layer_prefix + "self_attn.q_proj.weight" in hf_state_dict:
-            te_state_dict[layer_prefix + "self_attention.layernorm_qkv.query_weight"].data[:] = (
-                hf_state_dict[layer_prefix + "self_attn.q_proj.weight"].data[:]
-            )
+            te_state_dict[
+                layer_prefix + "self_attention.layernorm_qkv.query_weight"
+            ].data[:] = hf_state_dict[layer_prefix + "self_attn.q_proj.weight"].data[:]
 
         if layer_prefix + "self_attn.k_proj.weight" in hf_state_dict:
-            te_state_dict[layer_prefix + "self_attention.layernorm_qkv.key_weight"].data[:] = (
-                hf_state_dict[layer_prefix + "self_attn.k_proj.weight"].data[:]
-            )
+            te_state_dict[
+                layer_prefix + "self_attention.layernorm_qkv.key_weight"
+            ].data[:] = hf_state_dict[layer_prefix + "self_attn.k_proj.weight"].data[:]
 
         if layer_prefix + "self_attn.v_proj.weight" in hf_state_dict:
-            te_state_dict[layer_prefix + "self_attention.layernorm_qkv.value_weight"].data[:] = (
-                hf_state_dict[layer_prefix + "self_attn.v_proj.weight"].data[:]
-            )
+            te_state_dict[
+                layer_prefix + "self_attention.layernorm_qkv.value_weight"
+            ].data[:] = hf_state_dict[layer_prefix + "self_attn.v_proj.weight"].data[:]
 
         if layer_prefix + "self_attn.o_proj.weight" in hf_state_dict:
-            te_state_dict[layer_prefix + "self_attention.proj.weight"].data[:] = hf_state_dict[
-                layer_prefix + "self_attn.o_proj.weight"
-            ].data[:]
+            te_state_dict[layer_prefix + "self_attention.proj.weight"].data[:] = (
+                hf_state_dict[layer_prefix + "self_attn.o_proj.weight"].data[:]
+            )
 
         if layer_prefix + "post_attention_layernorm.weight" in hf_state_dict:
-            te_state_dict[layer_prefix + "layernorm_mlp.layer_norm_weight"].data[:] = hf_state_dict[
-                layer_prefix + "post_attention_layernorm.weight"
-            ].data[:]
+            te_state_dict[layer_prefix + "layernorm_mlp.layer_norm_weight"].data[:] = (
+                hf_state_dict[layer_prefix + "post_attention_layernorm.weight"].data[:]
+            )
 
         # It may happen that gate_proj.weight and up_proj.weight will be in the different files, so we need to
         # load them separately.
@@ -210,7 +230,7 @@ def _replace_params(hf_state_dict, te_state_dict, config):
             ] = hf_state_dict[layer_prefix + "mlp.up_proj.weight"].data
 
         if layer_prefix + "mlp.down_proj.weight" in hf_state_dict:
-            te_state_dict[layer_prefix + "layernorm_mlp.fc2_weight"].data[:] = hf_state_dict[
-                layer_prefix + "mlp.down_proj.weight"
-            ].data[:]
+            te_state_dict[layer_prefix + "layernorm_mlp.fc2_weight"].data[:] = (
+                hf_state_dict[layer_prefix + "mlp.down_proj.weight"].data[:]
+            )
     return all_layer_prefixes
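
Although the hunks above only re-wrap long lines, they pass through the two pieces that make this file work: replace_decoder, a context manager that temporarily monkey-patches transformers.models.llama.modeling_llama.LlamaDecoderLayer, and TELlamaModel.__new__, which constructs a LlamaModel while that patch is active. A simplified usage sketch based on the hunks above (it assumes transformer_engine is installed and a CUDA device is available; the config values are placeholders):

from transformers.models.llama.modeling_llama import LlamaConfig, LlamaModel

from ChatTTS.model.cuda.te_llama import TELlamaDecoderLayer, replace_decoder

# Toy configuration; ChatTTS derives the real one from its GPT config.
config = LlamaConfig(
    hidden_size=256,
    num_hidden_layers=2,
    num_attention_heads=4,
    num_key_value_heads=4,
)

# While the patch is active, every decoder layer LlamaModel builds is a
# TELlamaDecoderLayer; the finally-block shown above restores the original class.
with replace_decoder(te_decoder_cls=TELlamaDecoderLayer):
    model = LlamaModel(config)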
ChatTTS/model/gpt.py (5 changes: 4 additions & 1 deletion)
@@ -118,11 +118,14 @@ def _build_llama(
         if "cuda" in str(device):
             try:
                 from .cuda import TELlamaModel
+
                 model = TELlamaModel(llama_config)
                 self.logger.info("use NVIDIA accelerated TELlamaModel")
             except Exception as e:
                 model = None
-                self.logger.warn(f"use default LlamaModel for importing TELlamaModel error: {e}")
+                self.logger.warn(
+                    f"use default LlamaModel for importing TELlamaModel error: {e}"
+                )
         if model is None:
             model = LlamaModel(llama_config)
         del model.embed_tokens
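
Besides re-wrapping the long logger call, the blank line added after "from .cuda import TELlamaModel" is also part of black's output, which separates an import from the code that follows it. The surrounding logic, unchanged by this commit, is a standard optional-dependency fallback; a self-contained sketch of the same idea (the import path, logger setup, and device handling are simplified stand-ins, not the repository's exact code):

import logging

from transformers.models.llama.modeling_llama import LlamaConfig, LlamaModel

logger = logging.getLogger(__name__)


def build_llama(llama_config: LlamaConfig, device: str) -> LlamaModel:
    # Prefer the TransformerEngine-accelerated model on CUDA; on any import or
    # initialization error, fall back to the stock LlamaModel.
    model = None
    if "cuda" in device:
        try:
            from ChatTTS.model.cuda import TELlamaModel

            model = TELlamaModel(llama_config)
            logger.info("use NVIDIA accelerated TELlamaModel")
        except Exception as e:
            model = None
            logger.warning(
                f"use default LlamaModel for importing TELlamaModel error: {e}"
            )
    if model is None:
        model = LlamaModel(llama_config)
    return model.to(device)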
