diff --git a/src/axolotl/cli/__init__.py b/src/axolotl/cli/__init__.py
index 385e94f9ad..a156342474 100644
--- a/src/axolotl/cli/__init__.py
+++ b/src/axolotl/cli/__init__.py
@@ -344,8 +344,16 @@ def load_cfg(config: Union[str, Path] = Path("examples/"), **kwargs):
 
     cfg.axolotl_config_path = config
 
+    try:
+        device_props = torch.cuda.get_device_properties("cuda")
+        gpu_version = "sm_" + str(device_props.major) + str(device_props.minor)
+    except:  # pylint: disable=bare-except  # noqa: E722
+        gpu_version = None
+
     capabilities = GPUCapabilities(
-        bf16=is_torch_bf16_gpu_available(), n_gpu=os.environ.get("WORLD_SIZE", 1)
+        bf16=is_torch_bf16_gpu_available(),
+        n_gpu=os.environ.get("WORLD_SIZE", 1),
+        compute_capability=gpu_version,
     )
 
     cfg = validate_config(cfg, capabilities=capabilities)
diff --git a/src/axolotl/utils/config/models/input/v0_4_1/__init__.py b/src/axolotl/utils/config/models/input/v0_4_1/__init__.py
index c67c7d5f04..f5407c62e6 100644
--- a/src/axolotl/utils/config/models/input/v0_4_1/__init__.py
+++ b/src/axolotl/utils/config/models/input/v0_4_1/__init__.py
@@ -395,25 +395,18 @@ class AxolotlInputConfig(
 
     bf16: Optional[Union[AutoType, bool]] = AutoType.AUTO
     fp16: Optional[bool] = None
-    bfloat16: Optional[bool] = None
-    float16: Optional[bool] = None
+    bfloat16: Optional[bool] = None  # for non-AMP cases
+    float16: Optional[bool] = None  # for non-AMP cases
     tf32: Optional[bool] = None
     float32: Optional[bool] = None
 
     # torch_dtype: Optional[torch.dtype]
 
-    is_falcon_derived_model: Optional[bool] = Field(default=False)
-    is_llama_derived_model: Optional[bool] = Field(default=False)
-    is_mistral_derived_model: Optional[bool] = Field(default=False)
-    is_qwen_derived_model: Optional[bool] = Field(default=False)
-
     gradient_checkpointing: Optional[bool] = Field(default=False)
     gradient_checkpointing_kwargs: Optional[Dict[str, Any]] = None
 
     unfrozen_parameters: Optional[List[str]] = None
 
-    is_preprocess: Optional[bool] = None
-
     sequence_len: int = Field(default=1024)
     sample_packing: Optional[bool] = None
     eval_sample_packing: Optional[bool] = None
@@ -456,17 +449,24 @@ class AxolotlInputConfig(
     neftune_noise_alpha: Optional[float] = None
 
     max_memory: Optional[Union[int, str]] = None
-    gpu_memory_limit: Optional[str] = None
+    gpu_memory_limit: Optional[Union[int, str]] = None
 
     chat_template: Optional[Union[str, ChatTemplate]] = None
     default_system_message: Optional[str] = None
 
-    # INTERNALS - document for now
+    # INTERNALS - document for now, generally not set externally
+    is_preprocess: Optional[bool] = None
+
     total_num_tokens: Optional[int] = None
     total_supervised_tokens: Optional[int] = None
     sample_packing_eff_est: Optional[float] = None
     axolotl_config_path: Optional[str] = None
 
+    is_falcon_derived_model: Optional[bool] = Field(default=False)
+    is_llama_derived_model: Optional[bool] = Field(default=False)
+    is_mistral_derived_model: Optional[bool] = Field(default=False)
+    is_qwen_derived_model: Optional[bool] = Field(default=False)
+
     @field_validator("datasets")
     @classmethod
     def check_non_empty_datasets(cls, datasets):
@@ -500,21 +500,13 @@ def fix_sharegpt_datasets(cls, datasets):
 
     @model_validator(mode="before")
     @classmethod
-    def check_batch_size_fields(cls, root):
-        non_empty_count = sum(
-            1
-            for field in (
-                "micro_batch_size",
-                "gradient_accumulation_steps",
-                "batch_size",
-            )
-            if root.get(field)
-        )
+    def check_batch_size_fields(cls, data):
+        fields = ("micro_batch_size", "gradient_accumulation_steps", "batch_size")
+        non_empty_count = sum(1 for field in fields if data.get(field))
+
         if non_empty_count < 2:
-            raise ValueError(
-                "At least two of [micro_batch_size, gradient_accumulation_steps, batch_size] must be set"
-            )
-        return root
+            raise ValueError(f"At least two of {', '.join(fields)} must be set")
+        return data
 
     @model_validator(mode="before")
     @classmethod
@@ -547,45 +539,29 @@ def check_gptq_w_revision(cls, data):
         return data
 
     @model_validator(mode="before")
     @classmethod
-    def check_sample_packing_w_xformers(cls, root):
-        if root.get("sample_packing") and root.get("xformers_attention"):
+    def check_sample_packing_w_xformers(cls, data):
+        if data.get("sample_packing") and data.get("xformers_attention"):
             raise ValueError(
                 "sample_packing not compatible with xformers_attention. Use flash_attention"
             )
-        return root
-
-    @model_validator(mode="before")
-    @classmethod
-    def check_sample_packing_w_sdpa_bf16(cls, root):
-        if (
-            root.get("sample_packing")
-            and root.get("sdp_attention")
-            and (root.get("bfloat16") or root.get("bf16"))
-        ):
-            # https://github.com/pytorch/pytorch/blob/1b03423526536b5f3d35bdfa95ccc6197556cf9b/test/test_transformers.py#L2440-L2450
-            LOG.warning(
-                "sample_packing & torch sdpa with bf16 is unsupported may results in 0.0 loss. "
-                "This may work on H100s."
-            )
-
-        return root
+        return data
 
     @model_validator(mode="before")
     @classmethod
-    def check_sample_packing_w_rl(cls, root):
-        if root.get("sample_packing") and root.get("rl"):
+    def check_sample_packing_w_rl(cls, data):
+        if data.get("sample_packing") and data.get("rl"):
             raise ValueError("`sample_packing: true` does not work with RLHF training")
-        return root
+        return data
 
     @model_validator(mode="before")
     @classmethod
-    def hint_sample_packing_padding(cls, root):
-        if root.get("sample_packing") and not root.get("pad_to_sequence_len"):
+    def hint_sample_packing_padding(cls, data):
+        if data.get("sample_packing") and not data.get("pad_to_sequence_len"):
             LOG.warning(
                 "`pad_to_sequence_len: true` is recommended when using sample_packing"
             )
-        return root
+        return data
 
     @model_validator(mode="before")
     @classmethod
@@ -929,3 +905,24 @@ def check_bf16(self):
                 "bf16 requested, but AMP is not supported on this GPU. Requires Ampere series or above."
             )
         return self
+
+    @model_validator(mode="before")
+    @classmethod
+    def check_sample_packing_w_sdpa_bf16(cls, data):
+        is_sm_90: bool = (
+            data["capabilities"]
+            and data["capabilities"].get("compute_capability") == "sm_90"
+        )
+        if (
+            data.get("sample_packing")
+            and data.get("sdp_attention")
+            and (data.get("bfloat16") or data.get("bf16"))
+            and not is_sm_90
+        ):
+            # https://github.com/pytorch/pytorch/blob/1b03423526536b5f3d35bdfa95ccc6197556cf9b/test/test_transformers.py#L2440-L2450
+            LOG.warning(
+                "sample_packing & torch sdpa with bf16 is unsupported may results in 0.0 loss. "
+                "This may work on H100s."
+            )
+
+        return data
diff --git a/src/axolotl/utils/config/models/internals/__init__.py b/src/axolotl/utils/config/models/internals/__init__.py
index 74ae2afca6..dd742caf45 100644
--- a/src/axolotl/utils/config/models/internals/__init__.py
+++ b/src/axolotl/utils/config/models/internals/__init__.py
@@ -1,4 +1,6 @@
 """module for gpu capabilities"""
+from typing import Optional
+
 from pydantic import BaseModel, Field
 
 
@@ -8,3 +10,5 @@ class GPUCapabilities(BaseModel):
     bf16: bool = Field(default=False)
     fp8: bool = Field(default=False)
     n_gpu: int = Field(default=1)
+    n_node: int = Field(default=1)
+    compute_capability: Optional[str] = Field(default=None)
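For context, a minimal standalone sketch of the logic this patch wires together: deriving an "sm_XY" compute-capability string from torch.cuda.get_device_properties and using it to gate the SDPA + bf16 sample-packing warning on sm_90 (H100-class) GPUs. The helper names below are illustrative only and are not part of the diff or the axolotl codebase.

# Illustrative sketch; mirrors the detection and gating added above, but is not the PR's code.
from typing import Optional

import torch


def detect_compute_capability() -> Optional[str]:
    """Return e.g. 'sm_80' (A100) or 'sm_90' (H100), or None when no CUDA GPU is visible."""
    if not torch.cuda.is_available():
        return None
    props = torch.cuda.get_device_properties(0)
    return f"sm_{props.major}{props.minor}"


def should_warn_sdpa_bf16(cfg: dict, compute_capability: Optional[str]) -> bool:
    """Warn for sample_packing + torch SDPA + bf16, unless running on sm_90 hardware."""
    return bool(
        cfg.get("sample_packing")
        and cfg.get("sdp_attention")
        and (cfg.get("bfloat16") or cfg.get("bf16"))
        and compute_capability != "sm_90"
    )


if __name__ == "__main__":
    cap = detect_compute_capability()
    print(cap, should_warn_sdpa_bf16({"sample_packing": True, "sdp_attention": True, "bf16": True}, cap))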