Flash Attention Disable Toggle (Take 2) (#118)
For hardware scenarios where flash attention may not be supported.

For issue #92

---------

Signed-off-by: Mustafa Eyceoz <[email protected]>
Maxusmusti authored Jul 2, 2024
1 parent 56384d5 commit 0ff1ba4
Showing 2 changed files with 29 additions and 13 deletions.
9 changes: 6 additions & 3 deletions src/instructlab/training/config.py
@@ -4,6 +4,7 @@
 
 # Standard
 from enum import Enum
+from typing import Optional
 import os
 
 # Third Party
@@ -96,9 +97,9 @@ class DeepSpeedOptions(BaseModel):
     Defaults are all taken from the above docs.
     """
 
-    cpu_offload_optimizer: bool = False
+    cpu_offload_optimizer: Optional[bool] = False
     cpu_offload_optimizer_ratio: float = 1
-    cpu_offload_optimizer_pin_memory: bool = False
+    cpu_offload_optimizer_pin_memory: Optional[bool] = False
 
     # don't save in deepspeed format as a default
     save_samples: int | None = None
@@ -139,7 +140,7 @@ class TrainingArgs(BaseModel):
     is_padding_free: bool
     random_seed: int = 42
 
-    mock_data: bool = False
+    mock_data: Optional[bool] = False
     mock_data_len: int = 0
 
     deepspeed_options: DeepSpeedOptions = Field(
@@ -150,6 +151,8 @@ class TrainingArgs(BaseModel):
         )
     )
 
+    disable_flash_attn: Optional[bool] = False
+
     # TODO(osilkin): support quantized full fine-tuning:
     # https://github.com/instructlab/training/issues/28
     # quantize_dtype: QuantizeDataType = QuantizeDataType.NONE
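The config change above amounts to a Pydantic field typed Optional[bool] with a False default. A standalone sketch of that pattern (an illustrative model only, not the real TrainingArgs, which has additional required fields not shown in this diff):

# Illustrative sketch of the Optional[bool] toggle pattern used in config.py above.
# AttnToggle is a made-up model name, not part of the repository.
from typing import Optional

from pydantic import BaseModel

class AttnToggle(BaseModel):
    disable_flash_attn: Optional[bool] = False

print(AttnToggle().disable_flash_attn)                         # False (default)
print(AttnToggle(disable_flash_attn=True).disable_flash_attn)  # True
print(AttnToggle(disable_flash_attn=None).disable_flash_attn)  # None is also accepted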
33 changes: 23 additions & 10 deletions src/instructlab/training/main_ds.py
@@ -94,24 +94,29 @@ def setup_model(args, tokenizer, train_loader, grad_accum):
             bnb_4bit_compute_dtype=torch.float16,  # if not set will throw a warning about slow speeds when training
         )
 
+    base_model_args = {
+        "pretrained_model_name_or_path": args.model_name_or_path,
+        "torch_dtype": torch.bfloat16,
+        "quantization_config": bnb_config,
+    }
+    if not args.disable_flash_attn:
+        base_model_args["attn_implementation"] = "flash_attention_2"
+    elif args.is_granite:
+        raise RuntimeError(
+            "ERROR: Trying to use padding-free transformer without flash attention is not supported"
+        )
+
     if args.is_granite:
         with ensure_loadable_granite_checkpoint(
             args.model_name_or_path, args.output_dir
         ) as path:
+            base_model_args["pretrained_model_name_or_path"] = path
             model = GPTDolomiteForCausalLM.from_pretrained(
-                path,
-                attn_implementation="flash_attention_2",
-                torch_dtype=torch.bfloat16,
+                **base_model_args,
                 use_padding_free_transformer=True,
-                quantization_config=bnb_config,
             )
     else:
-        model = AutoModelForCausalLM.from_pretrained(
-            args.model_name_or_path,
-            attn_implementation="flash_attention_2",
-            torch_dtype=torch.bfloat16,
-            quantization_config=bnb_config,
-        )
+        model = AutoModelForCausalLM.from_pretrained(**base_model_args)
 
     if len(tokenizer) > model.config.vocab_size:
         print(
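The net effect of the hunk above: model keyword arguments are collected in a shared base_model_args dict, and attn_implementation="flash_attention_2" is only added when flash attention stays enabled; combining the padding-free (granite) path with the new toggle is rejected. A simplified, self-contained sketch of that gating (not the actual setup_model function; dtype and quantization handling are omitted and the model path is a placeholder):

# Simplified mirror of the gating added in setup_model above; illustrative only.
from types import SimpleNamespace

def build_base_model_args(args):
    base_model_args = {"pretrained_model_name_or_path": args.model_name_or_path}
    if not args.disable_flash_attn:
        # Flash attention remains the default implementation when not disabled.
        base_model_args["attn_implementation"] = "flash_attention_2"
    elif args.is_granite:
        # Padding-free (granite) training requires flash attention.
        raise RuntimeError(
            "ERROR: Trying to use padding-free transformer without flash attention is not supported"
        )
    return base_model_args

args = SimpleNamespace(
    model_name_or_path="path/to/model",  # placeholder path
    disable_flash_attn=True,
    is_granite=False,
)
print(build_base_model_args(args))
# {'pretrained_model_name_or_path': 'path/to/model'}  -- no attn_implementation key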
@@ -613,6 +618,13 @@ def run_training(torch_args: TorchrunArgs, train_args: TrainingArgs) -> None:
     if train_args.is_padding_free:
         command.append("--is_granite")
 
+    if train_args.disable_flash_attn:
+        if train_args.is_padding_free:
+            raise RuntimeError(
+                "ERROR: Trying to use padding-free transformer without flash attention is not supported"
+            )
+        command.append("--disable_flash_attn")
+
     if train_args.lora:
         command.extend(
             [
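The hunk above propagates the toggle from the Python API to the torchrun child process and repeats the same guard at the wrapper level. A small self-contained sketch of that behaviour (illustrative; the real command list carries the full torchrun invocation and many more flags):

# Simplified mirror of the run_training change above; illustrative only.
from types import SimpleNamespace

def extend_command(command, train_args):
    if train_args.disable_flash_attn:
        if train_args.is_padding_free:
            raise RuntimeError(
                "ERROR: Trying to use padding-free transformer without flash attention is not supported"
            )
        command.append("--disable_flash_attn")
    return command

train_args = SimpleNamespace(disable_flash_attn=True, is_padding_free=False)
print(extend_command(["torchrun", "main_ds.py"], train_args))
# ['torchrun', 'main_ds.py', '--disable_flash_attn']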
@@ -750,6 +762,7 @@ def run_training(torch_args: TorchrunArgs, train_args: TrainingArgs) -> None:
             os.path.dirname(__file__), "chat_templates/ibm_generic_tmpl.py"
         ),
     )
+    parser.add_argument("--disable_flash_attn", action="store_true")
     args = parser.parse_args()
     set_random_seed(args.seed)
     main(args)
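Finally, the new argparse flag is a plain store_true option, so it defaults to False and flips to True only when passed on the command line. A standalone sketch of the parsing behaviour (a fresh parser, not the full one from main_ds.py):

# Standalone argparse sketch; only the new flag is registered here.
import argparse

parser = argparse.ArgumentParser()
parser.add_argument("--disable_flash_attn", action="store_true")

print(parser.parse_args([]).disable_flash_attn)                         # False
print(parser.parse_args(["--disable_flash_attn"]).disable_flash_attn)   # True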
