
Commit

initial commit
lee committed Aug 14, 2024
1 parent f9da68c commit ad33c1c
Showing 45 changed files with 3,207 additions and 390 deletions.
125 changes: 0 additions & 125 deletions ABOUT_THIS_TEMPLATE.md

This file was deleted.

7 changes: 7 additions & 0 deletions configs/__init__.py
@@ -0,0 +1,7 @@
# Copyright (c) Meta Platforms, Inc. and affiliates.
# This software may be used and distributed according to the terms of the Llama 2 Community License Agreement.

from configs.peft import lora_config, llama_adapter_config, prefix_config
from configs.fsdp import fsdp_config
from configs.training import train_config
from configs.inference import inference_config
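
These package-level imports expose each config dataclass via `from configs import ...`. As a rough sketch of how such dataclass configs are usually overridden from launch-time keyword arguments — the `update_config` helper below is an assumption for illustration, not part of this commit:

# Hypothetical helper, not in this diff: apply keyword overrides
# (e.g. lr=2e-4, use_peft=True) to one or more config dataclass instances.
def update_config(configs, **kwargs):
    if not isinstance(configs, (list, tuple)):
        configs = (configs,)
    for config in configs:
        for key, value in kwargs.items():
            if hasattr(config, key):
                setattr(config, key, value)  # override the default field value
    return configs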
50 changes: 50 additions & 0 deletions configs/datasets.py
@@ -0,0 +1,50 @@
# Copyright (c) Meta Platforms, Inc. and affiliates.
# This software may be used and distributed according to the terms of the Llama 2 Community License Agreement.

from dataclasses import dataclass


@dataclass
class samsum_dataset:
    dataset: str = "samsum_dataset"
    train_split: str = "train"
    test_split: str = "validation"
    input_length: int = 2048


@dataclass
class samsum_dataset2:
    dataset: str = "samsum_dataset2"
    train_split: str = "train"
    test_split: str = "validation"
    input_length: int = 2048


@dataclass
class cnndm_dataset:
    dataset: str = "cnndm_dataset"
    train_split: str = "train"
    test_split: str = "validation"
    input_length: int = 2048


@dataclass
class grammar_dataset:
    dataset: str = "grammar_dataset"
    train_split: str = "src/llama_recipes/datasets/grammar_dataset/gtrain_10k.csv"
    test_split: str = "src/llama_recipes/datasets/grammar_dataset/grammar_validation.csv"
    input_length: int = 2048


@dataclass
class alpaca_dataset:
    dataset: str = "alpaca_dataset"
    train_split: str = "train"
    test_split: str = "val"
    data_path: str = "src/llama_recipes/datasets/alpaca_data.json"


@dataclass
class custom_dataset:
    dataset: str = "custom_dataset"
    file: str = "examples/custom_dataset.py"
    train_split: str = "train"
    test_split: str = "validation"
19 changes: 19 additions & 0 deletions configs/fsdp.py
@@ -0,0 +1,19 @@
# Copyright (c) Meta Platforms, Inc. and affiliates.
# This software may be used and distributed according to the terms of the Llama 2 Community License Agreement.

from dataclasses import dataclass

from torch.distributed.fsdp import ShardingStrategy
from torch.distributed.fsdp.fully_sharded_data_parallel import StateDictType

@dataclass
class fsdp_config:
    mixed_precision: bool = True
    use_fp16: bool = False
    sharding_strategy: ShardingStrategy = ShardingStrategy.FULL_SHARD
    checkpoint_type: StateDictType = StateDictType.FULL_STATE_DICT  # TODO: use StateDictType.SHARDED_STATE_DICT for fsdp
    fsdp_activation_checkpointing: bool = False
    fsdp_cpu_offload: bool = False
    pure_bf16: bool = False
    optimizer: str = "AdamW"
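
A minimal sketch of how these fields could feed a PyTorch FSDP wrap, assuming bf16 mixed precision when `mixed_precision` is set and `use_fp16` is not; the `wrap_model` function and the flag interactions are assumptions, not code from this commit:

import torch
from torch.distributed.fsdp import FullyShardedDataParallel as FSDP, MixedPrecision

from configs.fsdp import fsdp_config

def wrap_model(model):
    cfg = fsdp_config()
    mp_policy = None
    if cfg.mixed_precision and not cfg.use_fp16:
        # Assumed bf16 policy for params, gradients, and buffers.
        mp_policy = MixedPrecision(
            param_dtype=torch.bfloat16,
            reduce_dtype=torch.bfloat16,
            buffer_dtype=torch.bfloat16,
        )
    return FSDP(
        model,
        sharding_strategy=cfg.sharding_strategy,
        mixed_precision=mp_policy,
        cpu_offload=None,  # cfg.fsdp_cpu_offload would map to CPUOffload(offload_params=True)
        device_id=torch.cuda.current_device(),
    )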

28 changes: 28 additions & 0 deletions configs/inference.py
@@ -0,0 +1,28 @@
from dataclasses import dataclass


@dataclass
class inference_config:
    model_name: str = None
    peft_model: str = None
    quantization: bool = False
    use_gumbel: bool = False
    max_new_tokens: int = 100  # The maximum number of tokens to generate
    prompt_file: str = None
    seed: int = 42  # seed value for reproducibility
    do_sample: bool = True  # Whether or not to use sampling; use greedy decoding otherwise.
    min_length: int = None  # The minimum length of the sequence to be generated, input prompt + min_new_tokens
    use_cache: bool = True  # [optional] Whether or not the model should use the past key/values attentions (if applicable to the model) to speed up decoding.
    top_p: float = 0.9  # [optional] If set to float < 1, only the smallest set of most probable tokens with probabilities that add up to top_p or higher are kept for generation.
    temperature: float = 0.01  # [optional] The value used to modulate the next token probabilities.
    top_k: int = 50  # [optional] The number of highest probability vocabulary tokens to keep for top-k filtering.
    repetition_penalty: float = 1.0  # The parameter for repetition penalty. 1.0 means no penalty.
    length_penalty: int = 1  # [optional] Exponential penalty to the length that is used with beam-based generation.
    enable_azure_content_safety: bool = False  # Enable safety check with Azure content safety API
    enable_sensitive_topics: bool = False  # Enable check for sensitive topics using AuditNLG APIs
    enable_salesforce_content_safety: bool = True  # Enable safety check with Salesforce safety flan t5
    max_padding_length: int = None  # the max padding length used by the tokenizer when padding the prompts
    use_fast_kernels: bool = False  # Enable using SDPA from PyTorch Accelerated Transformers; makes use of Flash Attention and Xformers memory-efficient kernels
    output_dir: str = "results"
    debugging: bool = False  # Enable debugging mode
    generation_prompt: bool = True  # Set add_generation_prompt
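
Most of these fields mirror Hugging Face `generate()` keyword arguments. A minimal sketch of how they might be forwarded, assuming a model and tokenizer are already loaded (the `generate_from_config` helper is an assumption, not part of this commit):

# Illustrative use of inference_config with Hugging Face generate(); not code from this diff.
import torch

def generate_from_config(model, tokenizer, prompt: str, cfg):
    batch = tokenizer(prompt, return_tensors="pt", truncation=True,
                      max_length=cfg.max_padding_length).to(model.device)
    with torch.no_grad():
        outputs = model.generate(
            **batch,
            max_new_tokens=cfg.max_new_tokens,
            do_sample=cfg.do_sample,
            top_p=cfg.top_p,
            temperature=cfg.temperature,
            top_k=cfg.top_k,
            min_length=cfg.min_length,
            use_cache=cfg.use_cache,
            repetition_penalty=cfg.repetition_penalty,
            length_penalty=cfg.length_penalty,
        )
    return tokenizer.decode(outputs[0], skip_special_tokens=True)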
26 changes: 26 additions & 0 deletions configs/peft.py
@@ -0,0 +1,26 @@
# Copyright (c) Meta Platforms, Inc. and affiliates.
# This software may be used and distributed according to the terms of the Llama 2 Community License Agreement.

from dataclasses import dataclass, field
from typing import List

@dataclass
class lora_config:
    r: int = 8
    lora_alpha: int = 32
    target_modules: List[str] = field(default_factory=lambda: ["q_proj", "v_proj"])
    bias: str = "none"
    task_type: str = "CAUSAL_LM"
    lora_dropout: float = 0.05
    inference_mode: bool = False


@dataclass
class llama_adapter_config:
    adapter_len: int = 10
    adapter_layers: int = 30
    task_type: str = "CAUSAL_LM"


@dataclass
class prefix_config:
    num_virtual_tokens: int = 30
    task_type: str = "CAUSAL_LM"
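
The field names of `lora_config` line up with the `peft` library's `LoraConfig`, so the dataclass can be unpacked directly. A sketch under that assumption (the `apply_lora` helper is illustrative, not part of the diff):

# Illustrative mapping from lora_config to the Hugging Face peft library; not code from this commit.
from dataclasses import asdict
from peft import LoraConfig, get_peft_model

from configs.peft import lora_config

def apply_lora(model):
    peft_config = LoraConfig(**asdict(lora_config()))
    model = get_peft_model(model, peft_config)
    model.print_trainable_parameters()  # sanity check on how many weights actually train
    return model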
58 changes: 58 additions & 0 deletions configs/training.py
@@ -0,0 +1,58 @@
# Copyright (c) Meta Platforms, Inc. and affiliates.
# This software may be used and distributed according to the terms of the Llama 2 Community License Agreement.

from dataclasses import dataclass


@dataclass
class train_config:
    model_name: str = "llama-7b"
    enable_fsdp: bool = True
    low_cpu_fsdp: bool = False
    run_validation: bool = True
    batch_size_training: int = 1
    gradient_accumulation_steps: int = 1
    num_epochs: int = 3
    num_workers_dataloader: int = 1
    lr: float = 1e-4
    weight_decay: float = 0.0
    gamma: float = 0.85
    seed: int = 42
    use_fp16: bool = False
    mixed_precision: bool = True
    val_batch_size: int = 1
    dataset: str = "samsum_dataset"
    peft_method: str = "lora"  # None, llama_adapter, prefix
    use_peft: bool = False
    output_dir: str = "model_output"
    freeze_layers: bool = False
    num_freeze_layers: int = 1
    quantization: bool = False
    one_gpu: bool = False
    save_model: bool = True
    dist_checkpoint_root_folder: str = "/storage/ukp/work/lee/intel_ukp_llm/intel_ukp_llm/llama-13b"  # will be used if using FSDP
    dist_checkpoint_folder: str = "checkpoints"  # will be used if using FSDP
    save_optimizer: bool = False  # will be used if using FSDP
    use_fast_kernels: bool = False  # Enable using SDPA from PyTorch Accelerated Transformers; makes use of Flash Attention and Xformers memory-efficient kernels
    gumbel: bool = False  # Enable Gumbel softmax for sampling
    gumbel_temperature: float = 1.0  # Gumbel softmax temperature
    gumbel_hard: bool = False  # Use hard Gumbel softmax
    gumbel_noskip_low: int = 2  # Layer to start skipping (low)
    gumbel_noskip_high: int = 32  # Layer to stop skipping (high)
    debugging: bool = False  # Enable debugging mode
    debugging_host: str = "localhost"  # Debugging host
    debugging_port: int = 5678  # Debugging port
    gumbel_target: float = 0.8  # Percent of layers that should be used (for calculating gumbel loss)
    gumbel_loss_multiplier: float = 50.0  # Simple multiplier for gumbel loss
    gumbel_loss_alpha: float = 0.8  # initial weighting factor for the gumbel loss
    gumbel_loss_beta: float = 0.0005  # controls the rate at which the weighting factor decreases
    use_token_max: bool = False  # Use max function over tokens instead of token mean
    use_only_last_token: bool = False  # Use only last token for classification
    use_only_past_key_values: bool = False  # Use only past key values for classification
    share_layer: bool = False  # Share one gumbel layer across all layers
    gumbel_use_simple_classifier: bool = False  # Use simple classifier instead of gumbel
    gumbel_num_hidden_layers: int = 1  # Number of hidden layers
    gradient_clipping_value: float = 1.0  # gradient clipping value
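
The gumbel_* fields point at a learned layer-skipping gate trained with a Gumbel-softmax relaxation. A minimal sketch of how such a gate and a target-usage loss could be wired, purely as an assumption about how `gumbel_temperature`, `gumbel_hard`, `gumbel_target`, and `gumbel_loss_multiplier` are consumed (none of this code appears in the diff):

import torch
import torch.nn as nn
import torch.nn.functional as F

class LayerSkipGate(nn.Module):
    """Per-layer keep/skip decision via a Gumbel-softmax over two logits (illustrative)."""
    def __init__(self, hidden_size: int, temperature: float = 1.0, hard: bool = False):
        super().__init__()
        self.classifier = nn.Linear(hidden_size, 2)  # logits for [skip, keep]
        self.temperature = temperature
        self.hard = hard

    def forward(self, hidden_states: torch.Tensor) -> torch.Tensor:
        pooled = hidden_states.mean(dim=1)  # token mean (cf. use_token_max / use_only_last_token)
        logits = self.classifier(pooled)
        probs = F.gumbel_softmax(logits, tau=self.temperature, hard=self.hard)
        return probs[:, 1]  # probability of keeping the layer

def gumbel_usage_loss(keep_probs: torch.Tensor, target: float, multiplier: float) -> torch.Tensor:
    # Penalize deviation of the mean fraction of kept layers from gumbel_target.
    return multiplier * (keep_probs.mean() - target) ** 2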



21 changes: 0 additions & 21 deletions docs/index.md

This file was deleted.

