Update for Recent Changes and Granite Model Class Support (#11)
* Update for granite model class support

Signed-off-by: Mustafa Eyceoz <[email protected]>

* Add mixins

Signed-off-by: Mustafa Eyceoz <[email protected]>

* Removing rmsnorm options to avoid optional checks

Signed-off-by: Mustafa Eyceoz <[email protected]>

* Remove TP import

Signed-off-by: Mustafa Eyceoz <[email protected]>

* Add config init

Signed-off-by: Mustafa Eyceoz <[email protected]>

* Remove granite moe

Signed-off-by: Mustafa Eyceoz <[email protected]>

* Remove mixtral

Signed-off-by: Mustafa Eyceoz <[email protected]>

* Remove excess register stuff

Signed-off-by: Mustafa Eyceoz <[email protected]>

---------

Signed-off-by: Mustafa Eyceoz <[email protected]>
Maxusmusti authored Nov 1, 2024
1 parent da678a5 commit 5fca4cc
Showing 59 changed files with 3,055 additions and 1,836 deletions.
80 changes: 80 additions & 0 deletions src/instructlab/dolomite/enums.py
@@ -11,3 +11,83 @@ class ParamsGroupMethod(Enum):

class GradientCheckpointingMethod(Enum):
block = "block"


class LRDecaySchedule(str, Enum):
constant = "constant"
cosine = "cosine"
exponential = "exponential"
linear = "linear"
power = "power"


class AttentionImplementation(Enum):
"""
Enum class for attention implementation
"""

eager = "eager"
sdpa = "sdpa"
flash_attention_2 = "flash_attention_2"


class MoEImplementation(Enum):
"""
Enum class for MoE implementation
"""

eager = "eager"
scattermoe = "scattermoe"


class DatasetSplit(str, Enum):
"""dataset split"""

train = "train"
val = "val"
test = "test"


class Mode(str, Enum):
"""training / inference mode"""

training = "training"
inference = "inference"
unsharding = "unsharding"
distillation = "distillation"


class TuningMethod(str, Enum):
"""training method"""

pretraining = "pretraining"
full_finetuning = "full_finetuning"
prompt_tuning = "prompt_tuning"
lora = "lora"
distillation = "distillation"


class FP8Backend(str, Enum):
msamp = "msamp"
nvte = "nvte"


class LossMask(str, Enum):
"""Type of loss masking method"""

output_only = "output_only"
no_mask = "no_mask"


class KLDivergenceMethod(str, Enum):
"""Type of KL divergence"""

forward = "forward"
backward = "backward"


class ExperimentsTrackerName(str, Enum):
"""Experiment tracker to use"""

aim = "aim"
wandb = "wandb"
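The training-side enums added above are plain Enum / str-backed Enum classes, so raw strings from a training config round-trip through them directly and unknown values fail fast. A minimal usage sketch (the raw_config dict is illustrative, not part of this commit):

# Minimal sketch of how the new enums validate raw config strings.
# The raw_config dict below is illustrative, not part of this commit.
from instructlab.dolomite.enums import (
    AttentionImplementation,
    DatasetSplit,
    LRDecaySchedule,
    TuningMethod,
)

raw_config = {
    "attention_implementation": "flash_attention_2",
    "lr_decay_schedule": "cosine",
    "tuning_method": "full_finetuning",
    "split": "train",
}

# Constructing an enum from its value validates it; an unknown string raises ValueError.
attn_impl = AttentionImplementation(raw_config["attention_implementation"])
schedule = LRDecaySchedule(raw_config["lr_decay_schedule"])
method = TuningMethod(raw_config["tuning_method"])
split = DatasetSplit(raw_config["split"])

# str-backed enums compare equal to their underlying strings.
assert schedule == "cosine"
assert split == "train"
assert attn_impl.value == "flash_attention_2"  # AttentionImplementation is a plain Enum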
2 changes: 1 addition & 1 deletion src/instructlab/dolomite/hf_models/__init__.py
@@ -2,7 +2,7 @@
# Extracted from https://github.com/ibm-granite/dolomite-engine
# ----------------------------------------------------------------
# Local
from .config import GPTDolomiteConfig
from .models.gpt_dolomite.config import GPTDolomiteConfig
from .model_conversion import export_to_huggingface, import_from_huggingface
from .models import GPTDolomiteForCausalLM, GPTDolomiteModel
from .register_hf import register_model_classes
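Since the package root still re-exports GPTDolomiteConfig from its new location under models/gpt_dolomite/, downstream imports through instructlab.dolomite.hf_models are unchanged; only code reaching into hf_models.config directly is affected. A short sketch, under the assumption that the defaults shown for CommonConfig below carry over to GPTDolomiteConfig:

# Public imports keep working; only the internal home of the config class moved.
from instructlab.dolomite.hf_models import (
    GPTDolomiteConfig,
    GPTDolomiteForCausalLM,
    GPTDolomiteModel,
)

# Assumption: GPTDolomiteConfig keeps the defaults shown for CommonConfig below,
# so a bare instantiation passes the MQA / position-embedding / init-method checks.
config = GPTDolomiteConfig()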
43 changes: 12 additions & 31 deletions src/instructlab/dolomite/hf_models/config.py
@@ -1,15 +1,9 @@
# ----------------------------------------------------------------
# Extracted from https://github.com/ibm-granite/dolomite-engine
# ----------------------------------------------------------------
# Third Party
from transformers import PretrainedConfig

# Local
from .enums import AttentionHeadType, PositionEmbeddingType
from .enums import AttentionHeadType, InitMethod, PositionEmbeddingType


class GPTDolomiteConfig(PretrainedConfig):
model_type = "gpt_dolomite"
class CommonConfig(PretrainedConfig):
keys_to_ignore_at_inference = ["past_key_values"]
attribute_map = {
"hidden_size": "n_embd",
@@ -18,20 +12,15 @@ class GPTDolomiteConfig(PretrainedConfig):
"num_hidden_layers": "n_layer",
}

# NOTE: initializer_range is kept for backward compatibility
# but it is not used anymore;
# rope_scaling is also not used anymore but kept for
# the same reason.

def __init__(
self,
vocab_size: int = 50257,
n_positions: int = 1024,
n_embd: int = 768,
n_layer: int = 12,
n_head: int = 12,
num_key_value_heads: int = None,
n_inner: int = None,
num_key_value_heads: int | None = None,
n_inner: int | None = None,
activation_function: str = "gelu_pytorch_tanh",
attention_head_type: str = "mqa",
resid_pdrop: float = 0.1,
@@ -41,20 +30,19 @@ def __init__(
layer_norm_epsilon: float = 1e-5,
initializer_range: float = 0.02,
scale_attn_weights: bool = True,
attention_multiplier: float = None,
attention_multiplier: float | None = None,
use_cache: bool = True,
bos_token_id: int = 50256,
eos_token_id: int = 50256,
pad_token_id: int = 50256,
attention_softmax_in_fp32: bool = True,
scale_attention_softmax_in_fp32: bool = True,
add_bias: bool = True,
position_embedding_type: str = "learned_absolute",
rope_theta: int = 10000,
rope_scaling: dict = None,
m_emb: float = None,
m_width: float = None,
m_residual: float = None,
rope_scaling: dict | None = None,
m_emb: float | None = None,
m_width: float | None = None,
m_residual: float | None = None,
init_method: str = "normal",
upcast_logits_for_loss: bool = False,
**kwargs,
@@ -78,7 +66,6 @@ def __init__(
self.attention_multiplier = attention_multiplier
self.use_cache = use_cache
self.attention_softmax_in_fp32 = attention_softmax_in_fp32
self.scale_attention_softmax_in_fp32 = scale_attention_softmax_in_fp32
self.position_embedding_type = position_embedding_type
self.add_bias = add_bias
self.rope_theta = rope_theta
@@ -93,6 +80,7 @@ def __init__(
assert self.scale_attn_weights

# check if enums are valid
init_method = InitMethod(init_method)
attention_head_type = AttentionHeadType(attention_head_type)
position_embedding_type = PositionEmbeddingType(position_embedding_type)

@@ -110,9 +98,7 @@ def __init__(
if self.num_key_value_heads is None:
self.num_key_value_heads = 1

assert (
self.num_key_value_heads == 1
), "MultiQueryAttention should have 1 head for keys and values"
assert self.num_key_value_heads == 1, "MultiQueryAttention should have 1 head for keys and values"
elif attention_head_type == AttentionHeadType.gqa:
assert (
self.num_key_value_heads is not None
@@ -122,9 +108,4 @@ def __init__(
self.n_head % self.num_key_value_heads == 0
), "GroupedQueryAttention should have more than 1 head for keys and values"

super().__init__(
bos_token_id=bos_token_id,
eos_token_id=eos_token_id,
pad_token_id=pad_token_id,
**kwargs,
)
super().__init__(bos_token_id=bos_token_id, eos_token_id=eos_token_id, pad_token_id=pad_token_id, **kwargs)
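The refactored CommonConfig now runs init_method through the InitMethod enum in addition to the existing attention-head checks. A small sketch of the validation rules visible in the hunk above, assuming CommonConfig is imported from the module path shown in this diff:

# Sketch of the validation behaviour encoded in CommonConfig.__init__ above.
from instructlab.dolomite.hf_models.config import CommonConfig

# MQA (the default): num_key_value_heads is forced to 1.
mqa = CommonConfig(attention_head_type="mqa")
assert mqa.num_key_value_heads == 1

# GQA: num_key_value_heads must be set and must divide n_head evenly.
gqa = CommonConfig(attention_head_type="gqa", n_head=12, num_key_value_heads=4)

# init_method is validated through the InitMethod enum ("normal" or "mup");
# any other string raises ValueError before the config is built.
mup = CommonConfig(init_method="mup")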
1 change: 1 addition & 0 deletions src/instructlab/dolomite/hf_models/defaults.py
@@ -0,0 +1 @@
DEFAULT_NORMALIZATION_IMPLEMENTATION = "torch"
9 changes: 5 additions & 4 deletions src/instructlab/dolomite/hf_models/enums.py
@@ -1,10 +1,11 @@
# ----------------------------------------------------------------
# Extracted from https://github.com/ibm-granite/dolomite-engine
# ----------------------------------------------------------------
# Standard
from enum import Enum


class InitMethod(Enum):
normal = "normal"
mup = "mup"


class PositionEmbeddingType(Enum):
"""
Enum class for position embeddings
4 changes: 4 additions & 0 deletions src/instructlab/dolomite/hf_models/mixins/__init__.py
@@ -0,0 +1,4 @@
from .dense import BaseModelMixin, CausalLMModelMixin, PreTrainedModelMixin
#from .dense_TP import BaseModelMixin_TP, CausalLMModelMixin_TP, PreTrainedModelMixin_TP
from .moe import BaseMoEModelMixin, CausalLMMoEModelMixin, PreTrainedMoEModelMixin
#from .moe_TP import BaseMoEModelMixin_TP, CausalLMMoEModelMixin_TP, PreTrainedMoEModelMixin_TP
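The dense and MoE mixins exported here are what the concrete Granite/Dolomite model classes build on (the tensor-parallel variants stay commented out for now). A purely illustrative sketch of that pattern follows; the concrete class names and the base_model_class hook are assumptions, not code from this commit:

# Hypothetical illustration of composing a model from the exported mixins.
# The class names and the base_model_class attribute are assumptions.
from instructlab.dolomite.hf_models.mixins import BaseModelMixin, CausalLMModelMixin


class MyDenseModel(BaseModelMixin):
    """Backbone: embeddings plus transformer blocks supplied by BaseModelMixin."""


class MyDenseModelForCausalLM(CausalLMModelMixin):
    """Causal-LM head and loss on top of the backbone."""

    # Assumed hook telling the causal-LM mixin which backbone class to build.
    base_model_class = MyDenseModel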
2 changes: 2 additions & 0 deletions src/instructlab/dolomite/hf_models/mixins/dense/__init__.py
@@ -0,0 +1,2 @@
from .base import BaseModelMixin, PreTrainedModelMixin
from .main import CausalLMModelMixin