Adding support for Meta-Llama-3.1-8B-Instruct. #15

Open · wants to merge 2 commits into main
5 changes: 3 additions & 2 deletions scripts/config.sh
@@ -18,9 +18,10 @@ export PANZA_SUMMARIZATION_BATCH_SIZE=8 # batch size for summarization.
 export PANZA_EVALUATION_BATCH_SIZE=1 # batch size for evaluation. Can safely be set to higher value (e.g., 8) if the GPU has enough capacity.
 
 export MODEL_PRECISION=bf16 # precision at which the base model is stored; options: bf16, fp32, or '4bit'
-# export PANZA_GENERATIVE_MODEL="mistralai/Mistral-7B-Instruct-v0.2"
-export PANZA_GENERATIVE_MODEL="ISTA-DASLab/Meta-Llama-3-8B-Instruct"
+export PANZA_GENERATIVE_MODEL="mistralai/Mistral-7B-Instruct-v0.2"
+# export PANZA_GENERATIVE_MODEL="ISTA-DASLab/Meta-Llama-3-8B-Instruct"
 # export PANZA_GENERATIVE_MODEL="microsoft/Phi-3-mini-4k-instruct"
+# export PANZA_GENERATIVE_MODEL="meta-llama/Meta-Llama-3.1-8B-Instruct"
 
 lowercased=$(echo "$PANZA_GENERATIVE_MODEL" | tr '[:upper:]' '[:lower:]')
 if [[ ${lowercased} == *llama* ]]; then
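Note on the change above: because config.sh lowercases the model name and matches on the `*llama*` glob, the new `meta-llama/Meta-Llama-3.1-8B-Instruct` entry is routed through the same branch as `Meta-Llama-3-8B-Instruct` once it is uncommented, with no further changes to the dispatch logic. A minimal sketch of that dispatch is below; the echo lines are hypothetical placeholders for whatever model-specific setup the rest of config.sh (not shown in this hunk) performs.

```bash
# Minimal illustration of the model-family dispatch in config.sh.
# The echo statements are hypothetical placeholders, not actual Panza logic.
export PANZA_GENERATIVE_MODEL="meta-llama/Meta-Llama-3.1-8B-Instruct"

lowercased=$(echo "$PANZA_GENERATIVE_MODEL" | tr '[:upper:]' '[:lower:]')
if [[ ${lowercased} == *llama* ]]; then
  echo "Llama-family model: Llama-specific settings apply"      # hypothetical
elif [[ ${lowercased} == *mistral* ]]; then
  echo "Mistral-family model: Mistral-specific settings apply"  # hypothetical
elif [[ ${lowercased} == *phi* ]]; then
  echo "Phi-family model: Phi-specific settings apply"          # hypothetical
fi
```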
95 changes: 95 additions & 0 deletions src/panza/finetuning/configs/llama3.1_4bit_rosa_panza.yaml
@@ -0,0 +1,95 @@
max_seq_len: 512
global_seed: 17
model_name_or_path: #TODO

load_path: # set via bash script to be absolute path to your sparse checkpoint
precision: fp32
hf_save_path: ./checkpoints

max_duration: # TODO
eval_interval: 1
seed: ${global_seed}

global_train_batch_size: #TODO
device_train_microbatch_size: 16
device_eval_batch_size: 16

run_name: # If left blank, will be read from env var $RUN_NAME

model:
  name: hf_causal_lm
  pretrained: true
  pretrained_model_name_or_path: ${model_name_or_path}
  max_seq_len: ${max_seq_len}
  output_hidden_states: true
  weight_bias_dtype: #TODO
  compute_dtype: fp32

rosa:
  lora_r: #TODO
  spa_d: #TODO
  lora_alpha: 16
  target_modules: 'all-linear'
  lora_dropout: 0.05
  impl: auto
  spa_store_transpose: true
  rosa_dtype: fp32
  spa_num_grads: 1
  grad_acc_mode: mean_squared
  grad_4bit_accum: true
  mask_load_path: #TODO
  mask_save_path: #TODO
  terminate_after_mask_generation: #TODO
  schedule: #TODO

tokenizer:
  name: ${model_name_or_path}
  kwargs:
    model_max_length: ${max_seq_len}

train_loader:
  name: finetuning
  dataset:
    hf_name: json
    split: train
    hf_kwargs:
      data_files: #TODO
    preprocessing_fn: preprocessing:panza_preprocessing_function
    max_seq_len: ${max_seq_len}
    allow_pad_trimming: false
    decoder_only_format: true
    shuffle: true
  drop_last: false
  num_workers: 8
  pin_memory: false
  prefetch_factor: 2
  persistent_workers: true
  timeout: 0

scheduler:
  name: linear_decay_with_warmup
  t_warmup: 20ba
  alpha_f: 0

optimizer:
  name: decoupled_adamw
  lr: # TODO
  betas:
  - 0.9
  - 0.999
  eps: 1.0e-8
  weight_decay: 0.0

progress_bar: false
log_to_console: true
console_log_interval: 1ba

callbacks:
  speed_monitor:
    window_size: 10
  lr_monitor: { }
  memory_monitor: { }
  runtime_estimator: { }

loggers:
  wandb: { }
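The new config follows the llm-foundry/omegaconf layout used by the existing Panza finetuning configs (${...} interpolation, an hf_causal_lm model, a finetuning data loader), with the #TODO fields meant to be supplied at launch time. Below is a hedged launch sketch, assuming an llm-foundry-style train.py entry point that accepts dotted key=value overrides; the script path and every override value are illustrative assumptions, not part of this PR.

```bash
# Hypothetical launch command: the train.py path and all override values are
# illustrative assumptions; only the config file path comes from this PR.
composer src/panza/finetuning/train.py \
  src/panza/finetuning/configs/llama3.1_4bit_rosa_panza.yaml \
  model_name_or_path=meta-llama/Meta-Llama-3.1-8B-Instruct \
  max_duration=3ep \
  global_train_batch_size=8 \
  model.weight_bias_dtype=4bit \
  rosa.lora_r=8 \
  rosa.spa_d=0.01 \
  optimizer.lr=1e-5 \
  train_loader.dataset.hf_kwargs.data_files=train.jsonl
```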