Skip to content

Commit

Permalink
use ds model for tokenizer
Browse files Browse the repository at this point in the history
  • Loading branch information
eric-haibin-lin committed Dec 16, 2024
1 parent a770cce commit f610e67
Show file tree
Hide file tree
Showing 5 changed files with 7 additions and 8 deletions.
2 changes: 1 addition & 1 deletion .github/workflows/gpu_test.yml
Original file line number Diff line number Diff line change
Expand Up @@ -28,7 +28,7 @@ jobs:
pip install -e .[test]
- name: Running dataset tests
run: |
rm -rf ~/verl-data && git clone --depth 1 https://github.com/eric-haibin-lin/verl-data ~/verl-data
[ ! -d "$HOME/verl-data" ] && git clone --depth 1 https://github.com/eric-haibin-lin/verl-data ~/verl-data
pytest -s -x tests/verl
- name: Running ray tests that need 2 GPUs
run: |
Expand Down
4 changes: 2 additions & 2 deletions examples/sft/gsm8k/run_gemma_2b.sh
Original file line number Diff line number Diff line change
Expand Up @@ -8,7 +8,7 @@ if [ "$#" -lt 2 ]; then
fi

nproc_per_node=$1
hdfs_path=$2
save_path=$2

# Shift the arguments so $@ refers to the rest
shift 2
Expand All @@ -23,7 +23,7 @@ torchrun --standalone --nnodes=1 --nproc_per_node=$nproc_per_node \
+data.response_dict_keys=['answer'] \
data.micro_batch_size=32 \
model.partial_pretrain=google/gemma-2b-it \
trainer.default_hdfs_dir=$hdfs_path \
trainer.default_local_dir=$save_path \
trainer.project_name=gsm8k-sft \
trainer.experiment_name=gsm8k-sft-gemma-2b-it \
trainer.total_epochs=3 \
Expand Down
4 changes: 2 additions & 2 deletions tests/verl/utils/dataset/test_sft_dataset.py
Original file line number Diff line number Diff line change
Expand Up @@ -28,7 +28,7 @@ def get_gsm8k_data():


def test_sft_dataset():
tokenizer = AutoTokenizer.from_pretrained('Qwen/Qwen2.5-0.5B-Instruct')
tokenizer = AutoTokenizer.from_pretrained('deepseek-ai/DeepSeek-Coder-V2-Lite-Instruct')
set_pad_token_id(tokenizer)
local_path = get_gsm8k_data()
dataset = SFTDataset(parquet_files=local_path,
Expand All @@ -42,4 +42,4 @@ def test_sft_dataset():
data = dataset[0]['input_ids']
output = tokenizer.batch_decode([data])[0]
assert len(output) > 1
assert type(output) == str
assert type(output) == str
2 changes: 1 addition & 1 deletion verl/trainer/fsdp_sft_trainer.py
Original file line number Diff line number Diff line change
Expand Up @@ -294,7 +294,7 @@ def save_checkpoint(self, step):

path = os.path.join(self.config.trainer.default_local_dir, f'global_step_{step}')
# save huggingface model
if self.device_mesh.get_rank() == 0:
if self.device_mesh.get_rank() == 0 and self.config.trainer.default_hdfs_dir:
os.makedirs(path, exist_ok=True)
hdfs_io.makedirs(self.config.trainer.default_hdfs_dir, exist_ok=True)
self.model.save_pretrained(path, state_dict=state_dict)
Expand Down
3 changes: 1 addition & 2 deletions verl/utils/logger/aggregate_logger.py
Original file line number Diff line number Diff line change
Expand Up @@ -14,7 +14,6 @@
"""
A Ray logger will receive logging info from different processes.
"""

import numbers
from typing import Dict

Expand All @@ -40,4 +39,4 @@ def flush(self):

def log(self, data, step):
if self.print_to_console:
print(concat_dict_to_str(data, step=step), flush=True)
print(concat_dict_to_str(data, step=step), flush=True)

0 comments on commit f610e67

Please sign in to comment.