Commit

updates from PR feedback
winglian committed Feb 22, 2024
1 parent 564e0a6 commit 032eced
Showing 3 changed files with 60 additions and 51 deletions.
10 changes: 9 additions & 1 deletion src/axolotl/cli/__init__.py
@@ -344,8 +344,16 @@ def load_cfg(config: Union[str, Path] = Path("examples/"), **kwargs):
 
     cfg.axolotl_config_path = config
 
+    try:
+        device_props = torch.cuda.get_device_properties("cuda")
+        gpu_version = "sm_" + str(device_props.major) + str(device_props.minor)
+    except:  # pylint: disable=bare-except  # noqa: E722
+        gpu_version = None
+
     capabilities = GPUCapabilities(
-        bf16=is_torch_bf16_gpu_available(), n_gpu=os.environ.get("WORLD_SIZE", 1)
+        bf16=is_torch_bf16_gpu_available(),
+        n_gpu=os.environ.get("WORLD_SIZE", 1),
+        compute_capability=gpu_version,
     )
 
     cfg = validate_config(cfg, capabilities=capabilities)
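
For reference, the probe above turns the device's CUDA compute capability into an "sm_" string (for example "sm_80" on A100-class GPUs, "sm_90" on H100) and falls back to None when no GPU is usable. A minimal standalone sketch of the same idea, assuming a CUDA-enabled PyTorch build:

    import torch

    if torch.cuda.is_available():
        props = torch.cuda.get_device_properties(0)
        # e.g. major=9, minor=0 on an H100 gives "sm_90"
        gpu_version = "sm_" + str(props.major) + str(props.minor)
    else:
        gpu_version = None  # mirrors the bare-except fallback in load_cfg above

    print(gpu_version)
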
97 changes: 47 additions & 50 deletions src/axolotl/utils/config/models/input/v0_4_1/__init__.py
@@ -395,25 +395,18 @@ class AxolotlInputConfig(
 
     bf16: Optional[Union[AutoType, bool]] = AutoType.AUTO
     fp16: Optional[bool] = None
-    bfloat16: Optional[bool] = None
-    float16: Optional[bool] = None
+    bfloat16: Optional[bool] = None  # for non-AMP cases
+    float16: Optional[bool] = None  # for non-AMP cases
     tf32: Optional[bool] = None
     float32: Optional[bool] = None
 
     # torch_dtype: Optional[torch.dtype]
 
-    is_falcon_derived_model: Optional[bool] = Field(default=False)
-    is_llama_derived_model: Optional[bool] = Field(default=False)
-    is_mistral_derived_model: Optional[bool] = Field(default=False)
-    is_qwen_derived_model: Optional[bool] = Field(default=False)
-
     gradient_checkpointing: Optional[bool] = Field(default=False)
     gradient_checkpointing_kwargs: Optional[Dict[str, Any]] = None
 
     unfrozen_parameters: Optional[List[str]] = None
 
-    is_preprocess: Optional[bool] = None
-
     sequence_len: int = Field(default=1024)
     sample_packing: Optional[bool] = None
     eval_sample_packing: Optional[bool] = None
@@ -456,17 +449,24 @@ class AxolotlInputConfig(
     neftune_noise_alpha: Optional[float] = None
 
     max_memory: Optional[Union[int, str]] = None
-    gpu_memory_limit: Optional[str] = None
+    gpu_memory_limit: Optional[Union[int, str]] = None
 
     chat_template: Optional[Union[str, ChatTemplate]] = None
     default_system_message: Optional[str] = None
 
-    # INTERNALS - document for now
+    # INTERNALS - document for now, generally not set externally
+    is_preprocess: Optional[bool] = None
+
     total_num_tokens: Optional[int] = None
     total_supervised_tokens: Optional[int] = None
     sample_packing_eff_est: Optional[float] = None
     axolotl_config_path: Optional[str] = None
 
+    is_falcon_derived_model: Optional[bool] = Field(default=False)
+    is_llama_derived_model: Optional[bool] = Field(default=False)
+    is_mistral_derived_model: Optional[bool] = Field(default=False)
+    is_qwen_derived_model: Optional[bool] = Field(default=False)
+
     @field_validator("datasets")
     @classmethod
     def check_non_empty_datasets(cls, datasets):
@@ -500,21 +500,13 @@ def fix_sharegpt_datasets(cls, datasets):
 
     @model_validator(mode="before")
     @classmethod
-    def check_batch_size_fields(cls, root):
-        non_empty_count = sum(
-            1
-            for field in (
-                "micro_batch_size",
-                "gradient_accumulation_steps",
-                "batch_size",
-            )
-            if root.get(field)
-        )
+    def check_batch_size_fields(cls, data):
+        fields = ("micro_batch_size", "gradient_accumulation_steps", "batch_size")
+        non_empty_count = sum(1 for field in fields if data.get(field))
 
         if non_empty_count < 2:
-            raise ValueError(
-                "At least two of [micro_batch_size, gradient_accumulation_steps, batch_size] must be set"
-            )
-        return root
+            raise ValueError(f"At least two of {', '.join(fields)} must be set")
+        return data
 
     @model_validator(mode="before")
     @classmethod
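
As a rough standalone sketch of what the refactored check enforces (helper name is illustrative, not part of the codebase), the validator counts how many of the three batch-size-related keys are set and requires at least two of them:

    def count_batch_fields(data: dict) -> int:
        # mirrors the counting logic in check_batch_size_fields
        fields = ("micro_batch_size", "gradient_accumulation_steps", "batch_size")
        return sum(1 for field in fields if data.get(field))

    assert count_batch_fields({"micro_batch_size": 2, "gradient_accumulation_steps": 4}) == 2  # ok
    assert count_batch_fields({"batch_size": 8}) == 1  # would raise ValueError in the validator
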
@@ -547,45 +539,29 @@ def check_gptq_w_revision(cls, data):
 
     @model_validator(mode="before")
     @classmethod
-    def check_sample_packing_w_xformers(cls, root):
-        if root.get("sample_packing") and root.get("xformers_attention"):
+    def check_sample_packing_w_xformers(cls, data):
+        if data.get("sample_packing") and data.get("xformers_attention"):
             raise ValueError(
                 "sample_packing not compatible with xformers_attention. Use flash_attention"
             )
 
-        return root
-
-    @model_validator(mode="before")
-    @classmethod
-    def check_sample_packing_w_sdpa_bf16(cls, root):
-        if (
-            root.get("sample_packing")
-            and root.get("sdp_attention")
-            and (root.get("bfloat16") or root.get("bf16"))
-        ):
-            # https://github.com/pytorch/pytorch/blob/1b03423526536b5f3d35bdfa95ccc6197556cf9b/test/test_transformers.py#L2440-L2450
-            LOG.warning(
-                "sample_packing & torch sdpa with bf16 is unsupported may results in 0.0 loss. "
-                "This may work on H100s."
-            )
-
-        return root
+        return data
 
     @model_validator(mode="before")
     @classmethod
-    def check_sample_packing_w_rl(cls, root):
-        if root.get("sample_packing") and root.get("rl"):
+    def check_sample_packing_w_rl(cls, data):
+        if data.get("sample_packing") and data.get("rl"):
             raise ValueError("`sample_packing: true` does not work with RLHF training")
-        return root
+        return data
 
     @model_validator(mode="before")
     @classmethod
-    def hint_sample_packing_padding(cls, root):
-        if root.get("sample_packing") and not root.get("pad_to_sequence_len"):
+    def hint_sample_packing_padding(cls, data):
+        if data.get("sample_packing") and not data.get("pad_to_sequence_len"):
             LOG.warning(
                 "`pad_to_sequence_len: true` is recommended when using sample_packing"
             )
-        return root
+        return data
 
     @model_validator(mode="before")
     @classmethod
@@ -929,3 +905,24 @@ def check_bf16(self):
                     "bf16 requested, but AMP is not supported on this GPU. Requires Ampere series or above."
                 )
         return self
+
+    @model_validator(mode="before")
+    @classmethod
+    def check_sample_packing_w_sdpa_bf16(cls, data):
+        is_sm_90: bool = (
+            data["capabilities"]
+            and data["capabilities"].get("compute_capability") == "sm_90"
+        )
+        if (
+            data.get("sample_packing")
+            and data.get("sdp_attention")
+            and (data.get("bfloat16") or data.get("bf16"))
+            and not is_sm_90
+        ):
+            # https://github.com/pytorch/pytorch/blob/1b03423526536b5f3d35bdfa95ccc6197556cf9b/test/test_transformers.py#L2440-L2450
+            LOG.warning(
+                "sample_packing & torch sdpa with bf16 is unsupported may results in 0.0 loss. "
+                "This may work on H100s."
+            )
+
+        return data
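
Roughly, the relocated validator now suppresses the warning on sm_90 (Hopper/H100) hardware, where the SDPA + bf16 combination is expected to work. A self-contained sketch of the gating condition, with the dictionary shape assumed from the diff and a hypothetical helper name:

    def should_warn(data: dict) -> bool:
        # warn only for sample_packing + torch SDPA + bf16 on pre-sm_90 GPUs
        capabilities = data.get("capabilities") or {}
        is_sm_90 = capabilities.get("compute_capability") == "sm_90"
        return bool(
            data.get("sample_packing")
            and data.get("sdp_attention")
            and (data.get("bfloat16") or data.get("bf16"))
            and not is_sm_90
        )

    cfg = {"sample_packing": True, "sdp_attention": True, "bf16": True}
    assert should_warn({**cfg, "capabilities": {"compute_capability": "sm_80"}})
    assert not should_warn({**cfg, "capabilities": {"compute_capability": "sm_90"}})
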
4 changes: 4 additions & 0 deletions src/axolotl/utils/config/models/internals/__init__.py
@@ -1,4 +1,6 @@
 """module for gpu capabilities"""
+from typing import Optional
+
 from pydantic import BaseModel, Field
 
 
@@ -8,3 +10,5 @@ class GPUCapabilities(BaseModel):
     bf16: bool = Field(default=False)
     fp8: bool = Field(default=False)
     n_gpu: int = Field(default=1)
+    n_node: int = Field(default=1)
+    compute_capability: Optional[str] = Field(default=None)
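
A quick usage sketch of the extended model (values are illustrative; the import path is assumed from the file location shown above):

    from axolotl.utils.config.models.internals import GPUCapabilities

    caps = GPUCapabilities(bf16=True, n_gpu=8, n_node=1, compute_capability="sm_90")
    print(caps.compute_capability)  # "sm_90"
    print(caps.model_dump())        # pydantic v2 serialization of all fields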
