generated from UKPLab/ukp-project-template
Commit
This commit does not belong to any branch on this repository, and may belong to a fork outside of the repository.
lee committed on Aug 14, 2024
commit ad33c1c (1 parent: f9da68c)
Showing 45 changed files with 3,207 additions and 390 deletions.
(A deleted file; its contents are not shown in this view.)
@@ -0,0 +1,7 @@
# Copyright (c) Meta Platforms, Inc. and affiliates.
# This software may be used and distributed according to the terms of the Llama 2 Community License Agreement.

from configs.peft import lora_config, llama_adapter_config, prefix_config
from configs.fsdp import fsdp_config
from configs.training import train_config
from configs.inference import inference_config
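
For orientation, a minimal usage sketch (not part of this commit; the package import path `configs` is inferred from the statements above): the re-exported config dataclasses can be instantiated with their defaults and selectively overridden per run.

from dataclasses import replace

from configs import train_config, lora_config  # re-exports shown above

cfg = train_config()                          # defaults defined later in this commit
cfg = replace(cfg, lr=2e-5, use_peft=True)    # override selected fields for one run
peft_cfg = lora_config(r=16)                  # e.g. raise the LoRA rank for this run
print(cfg.model_name, cfg.lr, peft_cfg.r)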
@@ -0,0 +1,50 @@
# Copyright (c) Meta Platforms, Inc. and affiliates.
# This software may be used and distributed according to the terms of the Llama 2 Community License Agreement.

from dataclasses import dataclass


@dataclass
class samsum_dataset:
    dataset: str = "samsum_dataset"
    train_split: str = "train"
    test_split: str = "validation"
    input_length: int = 2048

@dataclass
class samsum_dataset2:
    dataset: str = "samsum_dataset2"
    train_split: str = "train"
    test_split: str = "validation"
    input_length: int = 2048


@dataclass
class cnndm_dataset:
    dataset: str = "cnndm_dataset"
    train_split: str = "train"
    test_split: str = "validation"
    input_length: int = 2048

@dataclass
class grammar_dataset:
    dataset: str = "grammar_dataset"
    train_split: str = "src/llama_recipes/datasets/grammar_dataset/gtrain_10k.csv"
    test_split: str = "src/llama_recipes/datasets/grammar_dataset/grammar_validation.csv"
    input_length: int = 2048


@dataclass
class alpaca_dataset:
    dataset: str = "alpaca_dataset"
    train_split: str = "train"
    test_split: str = "val"
    data_path: str = "src/llama_recipes/datasets/alpaca_data.json"


@dataclass
class custom_dataset:
    dataset: str = "custom_dataset"
    file: str = "examples/custom_dataset.py"
    train_split: str = "train"
    test_split: str = "validation"
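
A hedged sketch (not part of this commit) of how a dataset config is typically selected by the name stored in `train_config.dataset`; the module path `configs.datasets` is inferred from context, not shown in this view.

import configs.datasets as datasets_module  # assumed module path for the file above
from configs.training import train_config

cfg = train_config()
dataset_cfg = getattr(datasets_module, cfg.dataset)()  # e.g. resolves "samsum_dataset"
dataset_cfg.input_length = 1024                        # override a field for this run
print(dataset_cfg.dataset, dataset_cfg.train_split, dataset_cfg.test_split)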
@@ -0,0 +1,19 @@
# Copyright (c) Meta Platforms, Inc. and affiliates.
# This software may be used and distributed according to the terms of the Llama 2 Community License Agreement.

from dataclasses import dataclass

from torch.distributed.fsdp import ShardingStrategy
from torch.distributed.fsdp.fully_sharded_data_parallel import StateDictType

@dataclass
class fsdp_config:
    mixed_precision: bool = True
    use_fp16: bool = False
    sharding_strategy: ShardingStrategy = ShardingStrategy.FULL_SHARD
    checkpoint_type: StateDictType = StateDictType.FULL_STATE_DICT  # TODO: use StateDictType.SHARDED_STATE_DICT for fsdp
    fsdp_activation_checkpointing: bool = False
    fsdp_cpu_offload: bool = False
    pure_bf16: bool = False
    optimizer: str = "AdamW"
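
A hedged sketch (consumer code, not part of this commit) of how these fields typically map onto PyTorch FSDP construction; the wrapped model is a placeholder.

import torch
from torch.distributed.fsdp import FullyShardedDataParallel as FSDP, MixedPrecision, CPUOffload
from configs.fsdp import fsdp_config

cfg = fsdp_config()
mp_policy = None
if cfg.mixed_precision and not cfg.use_fp16:
    # bf16 mixed precision when fp16 is not requested
    mp_policy = MixedPrecision(param_dtype=torch.bfloat16,
                               reduce_dtype=torch.bfloat16,
                               buffer_dtype=torch.bfloat16)
# model = FSDP(model,
#              sharding_strategy=cfg.sharding_strategy,
#              mixed_precision=mp_policy,
#              cpu_offload=CPUOffload(offload_params=cfg.fsdp_cpu_offload))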
@@ -0,0 +1,28 @@
from dataclasses import dataclass


@dataclass
class inference_config:
    model_name: str = None
    peft_model: str = None
    quantization: bool = False
    use_gumbel: bool = False
    max_new_tokens: int = 100  # The maximum number of tokens to generate
    prompt_file: str = None
    seed: int = 42  # Seed value for reproducibility
    do_sample: bool = True  # Whether to use sampling; greedy decoding otherwise
    min_length: int = None  # The minimum length of the sequence to be generated: input prompt + min_new_tokens
    use_cache: bool = True  # [optional] Whether the model should use the past key/values attentions (if applicable) to speed up decoding
    top_p: float = 0.9  # [optional] If set to float < 1, only the smallest set of most probable tokens with probabilities that add up to top_p or higher are kept for generation
    temperature: float = 0.01  # [optional] The value used to modulate the next token probabilities
    top_k: int = 50  # [optional] The number of highest probability vocabulary tokens to keep for top-k filtering
    repetition_penalty: float = 1.0  # The parameter for repetition penalty; 1.0 means no penalty
    length_penalty: int = 1  # [optional] Exponential penalty to the length, used with beam-based generation
    enable_azure_content_safety: bool = False  # Enable safety check with the Azure content safety API
    enable_sensitive_topics: bool = False  # Enable check for sensitive topics using AuditNLG APIs
    enable_salesforce_content_safety: bool = True  # Enable safety check with the Salesforce safety Flan-T5
    max_padding_length: int = None  # The max padding length used when the tokenizer pads the prompts
    use_fast_kernels: bool = False  # Enable SDPA from PyTorch Accelerated Transformers, using Flash Attention and Xformers memory-efficient kernels
    output_dir: str = "results"
    debugging: bool = False  # Enable debugging mode
    generation_prompt: bool = True  # Set add_generation_prompt
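
A hedged sketch (not part of this commit) of how these fields usually feed a Hugging Face `generate()` call; the model and tokenized batch are placeholders.

from configs.inference import inference_config

cfg = inference_config(model_name="path/to/llama-checkpoint")  # placeholder path
gen_kwargs = dict(
    max_new_tokens=cfg.max_new_tokens,
    do_sample=cfg.do_sample,
    top_p=cfg.top_p,
    temperature=cfg.temperature,
    top_k=cfg.top_k,
    repetition_penalty=cfg.repetition_penalty,
    length_penalty=cfg.length_penalty,
    use_cache=cfg.use_cache,
)
# outputs = model.generate(**batch, **gen_kwargs)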
@@ -0,0 +1,26 @@
# Copyright (c) Meta Platforms, Inc. and affiliates.
# This software may be used and distributed according to the terms of the Llama 2 Community License Agreement.

from dataclasses import dataclass, field
from typing import List

@dataclass
class lora_config:
    r: int = 8
    lora_alpha: int = 32
    target_modules: List[str] = field(default_factory=lambda: ["q_proj", "v_proj"])
    bias: str = "none"
    task_type: str = "CAUSAL_LM"
    lora_dropout: float = 0.05
    inference_mode: bool = False

@dataclass
class llama_adapter_config:
    adapter_len: int = 10
    adapter_layers: int = 30
    task_type: str = "CAUSAL_LM"

@dataclass
class prefix_config:
    num_virtual_tokens: int = 30
    task_type: str = "CAUSAL_LM"
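
A hedged sketch (not part of this commit) of how `lora_config` is commonly turned into a PEFT configuration; it assumes the `peft` package is installed and that the field names mirror `peft.LoraConfig`.

from dataclasses import asdict
from peft import LoraConfig, get_peft_model  # requires the peft package
from configs.peft import lora_config

peft_cfg = LoraConfig(**asdict(lora_config()))  # field names match peft.LoraConfig
# model = get_peft_model(model, peft_cfg)       # wrap a loaded base model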
@@ -0,0 +1,58 @@
# Copyright (c) Meta Platforms, Inc. and affiliates.
# This software may be used and distributed according to the terms of the Llama 2 Community License Agreement.

from dataclasses import dataclass


@dataclass
class train_config:
    model_name: str = "llama-7b"
    enable_fsdp: bool = True
    low_cpu_fsdp: bool = False
    run_validation: bool = True
    batch_size_training: int = 1
    gradient_accumulation_steps: int = 1
    num_epochs: int = 3
    num_workers_dataloader: int = 1
    lr: float = 1e-4
    weight_decay: float = 0.0
    gamma: float = 0.85
    seed: int = 42
    use_fp16: bool = False
    mixed_precision: bool = True
    val_batch_size: int = 1
    dataset: str = "samsum_dataset"
    peft_method: str = "lora"  # None, llama_adapter, prefix
    use_peft: bool = False
    output_dir: str = "model_output"
    freeze_layers: bool = False
    num_freeze_layers: int = 1
    quantization: bool = False
    one_gpu: bool = False
    save_model: bool = True
    dist_checkpoint_root_folder: str = "/storage/ukp/work/lee/intel_ukp_llm/intel_ukp_llm/llama-13b"  # will be used if using FSDP
    dist_checkpoint_folder: str = "checkpoints"  # will be used if using FSDP
    save_optimizer: bool = False  # will be used if using FSDP
    use_fast_kernels: bool = False  # Enable SDPA from PyTorch Accelerated Transformers, using Flash Attention and Xformers memory-efficient kernels
    gumbel: bool = False  # Enable Gumbel softmax for sampling
    gumbel_temperature: float = 1.0  # Gumbel softmax temperature
    gumbel_hard: bool = False  # Use hard Gumbel softmax
    gumbel_noskip_low: int = 2  # Layer to start skipping (low)
    gumbel_noskip_high: int = 32  # Layer to stop skipping (high)
    debugging: bool = False  # Enable debugging mode
    debugging_host: str = "localhost"  # Debugging host
    debugging_port: int = 5678  # Debugging port
    gumbel_target: float = 0.8  # Fraction of layers that should be used (for calculating the Gumbel loss)
    gumbel_loss_multiplier: float = 50.0  # Simple multiplier for the Gumbel loss
    gumbel_loss_alpha: float = 0.8  # Initial weighting factor for the Gumbel loss
    gumbel_loss_beta: float = 0.0005  # Controls the rate at which the weighting factor decreases
    use_token_max: bool = False  # Use a max over tokens instead of the token mean
    use_only_last_token: bool = False  # Use only the last token for classification
    use_only_past_key_values: bool = False  # Use only past key values for classification
    share_layer: bool = False  # Share one Gumbel layer across all layers
    gumbel_use_simple_classifier: bool = False  # Use a simple classifier instead of Gumbel
    gumbel_num_hidden_layers: int = 1  # Number of hidden layers
    gradient_clipping_value: float = 1.0  # Gradient clipping value
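
A hedged sketch (consumer code, not shown in this commit) of how such a config is typically overridden from command-line style keyword arguments before training; the `update_config` helper name is hypothetical.

from configs.training import train_config

def update_config(cfg, **kwargs):
    """Hypothetical helper: copy known keyword overrides onto the config."""
    for key, value in kwargs.items():
        if hasattr(cfg, key):
            setattr(cfg, key, value)
        else:
            print(f"Warning: unknown train_config field '{key}'")
    return cfg

cfg = update_config(train_config(), num_epochs=1, use_peft=True, peft_method="lora",
                    gumbel=True, gumbel_target=0.75)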
(Another deleted file; its contents are not shown in this view.)