add llama3 support for alpaca dataset #742

Open
wants to merge 1 commit into base: main
Conversation

wukaixingxp (Contributor)

What does this PR do?

This PR adds Llama 3 support for the Alpaca dataset so that users can run Alpaca fine-tuning, model conversion, and inference.

Fixes #634
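
For context, a minimal sketch of how an Alpaca-style record can be rendered with the Llama 3 chat template through the Hugging Face tokenizer. This is an illustration of the intended behavior, not the code added by this PR; the sample record and the way the optional `input` field is folded into the user turn are assumptions.

# Illustrative only: render one Alpaca-style record with the Llama 3 chat template.
# Assumes the Meta-Llama-3.1-8B-Instruct tokenizer is available locally.
from transformers import AutoTokenizer

tokenizer = AutoTokenizer.from_pretrained("meta-llama/Meta-Llama-3.1-8B-Instruct")

sample = {
    "instruction": "Give three tips for staying healthy.",
    "input": "",
    "output": "1. Eat a balanced diet. 2. Exercise regularly. 3. Get enough sleep.",
}

# Fold the optional `input` field into the user turn, as the Alpaca format does.
user_content = sample["instruction"]
if sample["input"]:
    user_content += "\n\n" + sample["input"]

messages = [
    {"role": "user", "content": user_content},
    {"role": "assistant", "content": sample["output"]},
]

# apply_chat_template inserts the Llama 3 special tokens
# (<|begin_of_text|>, <|start_header_id|>...<|end_header_id|>, <|eot_id|>) around each turn.
prompt_ids = tokenizer.apply_chat_template(messages, tokenize=True)
print(tokenizer.decode(prompt_ids))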

Feature/Issue validation/testing

Please describe the tests that you ran to verify your changes and summarize the relevant results. Provide instructions so the results can be reproduced, and list any relevant details of your test configuration.

  • Fine-tuning, checkpoint conversion, and inference now work with Llama 3:
(llama) [[email protected] ~/work/llama-recipes (main)]$ torchrun --nnodes 1 --nproc_per_node 8    ./recipes/quickstart/finetuning/finetuning.py  --model_name meta-llama/Meta-Llama-3.1-8B-Instruct --output_dir ./fsdp_fine_tune_results/output_model_1_8 --dist_checkpoint_root_folder ./fsdp_fine_tune_results/fsdp_model_finetuned_1_8  --enable_fsdp  --num_epochs 1 --batch_size_training 2 --dataset alpaca_dataset
W1021 17:51:38.391000 140142795121664 torch/distributed/run.py:779] 
W1021 17:51:38.391000 140142795121664 torch/distributed/run.py:779] *****************************************
W1021 17:51:38.391000 140142795121664 torch/distributed/run.py:779] Setting OMP_NUM_THREADS environment variable for each process to be 1 in default, to avoid your system being overloaded, please further tune the variable for optimal performance in your application as needed. 
W1021 17:51:38.391000 140142795121664 torch/distributed/run.py:779] *****************************************
/home/kaiwu/work/llama-recipes/src/llama_recipes/model_checkpointing/checkpoint_handler.py:17: DeprecationWarning: `torch.distributed._shard.checkpoint` will be deprecated, use `torch.distributed.checkpoint` instead
  from torch.distributed._shard.checkpoint import (
[warning repeated once per rank, 8 ranks total]
Clearing GPU cache for all ranks
--> Running with torch dist debug set to detail
Loading checkpoint shards: 100%|██████████| 4/4 [00:00<00:00,  4.24it/s]
[shard-loading progress bar repeated on each of the 8 ranks]
--> Model meta-llama/Meta-Llama-3.1-8B-Instruct

--> meta-llama/Meta-Llama-3.1-8B-Instruct has 8030.261248 Million params

bFloat16 enabled for mixed precision - using bfSixteen policy
--> applying fsdp activation checkpointing...
--> Training Set Length = 49402
--> Validation Set Length = 2600
Preprocessing dataset: 100%|██████████| 49402/49402 [00:14<00:00, 3294.56it/s]
length of dataset_train 1665
--> Num of Training Set Batches loaded = 104
Preprocessing dataset: 100%|██████████| 2600/2600 [00:00<00:00, 3319.79it/s]
--> Num of Validation Set Batches loaded = 10
Starting epoch 0/1
train_config.max_train_step: 0
[activation-checkpointing, preprocessing, batch-count, and epoch-start messages repeated once per rank; interleaved per-rank progress bars trimmed]
NCCL version 2.20.5+cuda12.4
/home/kaiwu/miniconda3/envs/llama/lib/python3.10/site-packages/torch/cuda/memory.py:343: FutureWarning: torch.cuda.reset_max_memory_allocated now calls torch.cuda.reset_peak_memory_stats, which resets /all/ peak memory stats.
  warnings.warn(
huggingface/tokenizers: The current process just got forked, after parallelism has already been used. Disabling parallelism to avoid deadlocks...
To disable this warning, you can either:
        - Avoid using `tokenizers` before the fork if possible
        - Explicitly set the environment variable TOKENIZERS_PARALLELISM=(true | false)
[the two warnings above repeated on each rank]
Training Epoch: 1:   0%|          | 0/104 [00:00<?, ?it/s]
/home/kaiwu/miniconda3/envs/llama/lib/python3.10/site-packages/torch/utils/checkpoint.py:295: FutureWarning: `torch.cpu.amp.autocast(args...)` is deprecated. Please use `torch.amp.autocast('cpu', args...)` instead.
  with torch.enable_grad(), device_autocast_ctx, torch.cpu.amp.autocast(**ctx.cpu_autocast_kwargs):  # type: ignore[attr-defined]
/home/kaiwu/miniconda3/envs/llama/lib/python3.10/site-packages/torch/utils/checkpoint.py:1399: FutureWarning: `torch.cpu.amp.autocast(args...)` is deprecated. Please use `torch.amp.autocast('cpu', args...)` instead.
  with device_autocast_ctx, torch.cpu.amp.autocast(**cpu_autocast_kwargs), recompute_context:  # type: ignore[attr-defined]
[both autocast deprecation warnings repeated on each of the 8 ranks]
Training Epoch: 1/1, step 103/104 completed (loss: 1.224412202835083): 100%|███████████████████████████████████████████████████| 104/104 [03:30<00:00,  2.03s/it]
Training Epoch: 1/1, step 103/104 completed (loss: 1.383886456489563): 100%|███████████████████████████████████████████████████| 104/104 [03:30<00:00,  2.02s/it]
Training Epoch: 1/1, step 103/104 completed (loss: 1.1475244760513306): 100%|██████████████████████████████████████████████████| 104/104 [03:30<00:00,  2.02s/it]
Training Epoch: 1/1, step 103/104 completed (loss: 1.1846411228179932): 100%|██████████████████████████████████████████████████| 104/104 [03:30<00:00,  2.03s/it]
Training Epoch: 1/1, step 103/104 completed (loss: 1.167246699333191): 100%|███████████████████████████████████████████████████| 104/104 [03:30<00:00,  2.02s/it]
Training Epoch: 1/1, step 103/104 completed (loss: 1.2066352367401123): 100%|██████████████████████████████████████████████████| 104/104 [03:30<00:00,  2.02s/it]
Training Epoch: 1/1, step 103/104 completed (loss: 1.2416218519210815): 100%|██████████████████████████████████████████████████| 104/104 [03:30<00:00,  2.02s/it]
Training Epoch: 1/1, step 103/104 completed (loss: 1.371912956237793): 100%|███████████████████████████████████████████████████| 104/104 [03:30<00:00,  2.02s/it]
Max CUDA memory allocated was 21 GB
Max CUDA memory reserved was 30 GB
Peak active CUDA memory was 22 GB
CUDA Malloc retries : 0
CPU Total Peak Memory consumed during the train (max): 9 GB
evaluating Epoch:   0%|          | 0/10 [00:00<?, ?it/s]
[huggingface/tokenizers fork warning repeated on each of the 8 ranks]
evaluating Epoch: 100%|██████████| 10/10 [00:02<00:00,  4.06it/s]
[evaluation progress bar repeated on each rank]
 eval_ppl=tensor(3.5136, device='cuda:0') eval_epoch_loss=tensor(1.2566, device='cuda:0')
 Saving the FSDP model checkpoints and optimizer using SHARDED_STATE_DICT
=====================================================
[message and separator printed by each of the 8 ranks]
Saving model to /home/kaiwu/work/llama-recipes/fsdp_fine_tune_results/fsdp_model_finetuned_1_8/fine-tuned-meta-llama/Meta-Llama-3.1-8B-Instruct
/home/kaiwu/miniconda3/envs/llama/lib/python3.10/site-packages/torch/distributed/fsdp/fully_sharded_data_parallel.py:689: FutureWarning: FSDP.state_dict_type() and FSDP.set_state_dict_type() are being deprecated. Please use APIs, get_state_dict() and set_state_dict(), which can support different parallelisms, FSDP1, FSDP2, DDP. API doc: https://pytorch.org/docs/stable/distributed.checkpoint.html#torch.distributed.checkpoint.state_dict.get_state_dict .Tutorial: https://pytorch.org/tutorials/recipes/distributed_checkpoint_recipe.html .
  warnings.warn(
[warning repeated on each of the 8 ranks]
/home/kaiwu/miniconda3/envs/llama/lib/python3.10/site-packages/torch/distributed/fsdp/_state_dict_utils.py:737: FutureWarning: Please use DTensor instead and we are deprecating ShardedTensor.
  local_shape = tensor.shape
/home/kaiwu/miniconda3/envs/llama/lib/python3.10/site-packages/torch/distributed/fsdp/_state_dict_utils.py:749: FutureWarning: Please use DTensor instead and we are deprecating ShardedTensor.
  tensor.shape,
/home/kaiwu/miniconda3/envs/llama/lib/python3.10/site-packages/torch/distributed/fsdp/_state_dict_utils.py:751: FutureWarning: Please use DTensor instead and we are deprecating ShardedTensor.
  tensor.dtype,
/home/kaiwu/miniconda3/envs/llama/lib/python3.10/site-packages/torch/distributed/fsdp/_state_dict_utils.py:752: FutureWarning: Please use DTensor instead and we are deprecating ShardedTensor.
  tensor.device,
[ShardedTensor deprecation warnings repeated on each of the 8 ranks]
/home/kaiwu/work/llama-recipes/src/llama_recipes/model_checkpointing/checkpoint_handler.py:113: FutureWarning: `save_state_dict` is deprecated and will be removed in future versions.Please use `save` instead.
  dist_cp.save_state_dict(
[warning repeated on each of the 8 ranks]
Sharded state checkpoint saved to /home/kaiwu/work/llama-recipes/fsdp_fine_tune_results/fsdp_model_finetuned_1_8/fine-tuned-meta-llama/Meta-Llama-3.1-8B-Instruct
Checkpoint Time = 18.6865

best eval loss on epoch 1 is 1.2566323280334473
Epoch 1: train_perplexity=4.5941, train_epoch_loss=1.5248, epoch time 211.6873932024464s
training params are saved in /home/kaiwu/work/llama-recipes/fsdp_fine_tune_results/fsdp_model_finetuned_1_8/fine-tuned-meta-llama/Meta-Llama-3.1-8B-Instruct/train_params.yaml
Key: avg_train_prep, Value: 4.594069480895996
Key: avg_train_loss, Value: 1.524766206741333
Key: avg_eval_prep, Value: 3.513568878173828
Key: avg_eval_loss, Value: 1.2566323280334473
Key: avg_epoch_time, Value: 211.6873932024464
Key: avg_checkpoint_time, Value: 18.68949384521693
[rank0]:[W1021 17:55:59.339848938 ProcessGroupNCCL.cpp:1168] Warning: WARNING: process group has NOT been destroyed before we destruct ProcessGroupNCCL. On normal program exit, the application should call destroy_process_group to ensure that any pending NCCL operations have finished in this process. In rare cases this process can exit before this point and block the progress of another member of the process group. This constraint has always been present,  but this warning has only been added since PyTorch 2.4 (function operator())


(llama) [[email protected] ~/work/llama-recipes (main)]$ python ./src/llama_recipes/inference/checkpoint_converter_fsdp_hf.py --fsdp_checkpoint_path ./fsdp_fine_tune_results/fsdp_model_finetuned_1_8/fine-tuned-meta-llama/Meta-Llama-3.1-8B-Instruct/ --consolidated_model_path ./fsdp_fine_tune_results/fsdp_model_finetuned_1_8_hf
/home/kaiwu/work/llama-recipes/src/llama_recipes/model_checkpointing/checkpoint_handler.py:17: DeprecationWarning: `torch.distributed._shard.checkpoint` will be deprecated, use `torch.distributed.checkpoint` instead
  from torch.distributed._shard.checkpoint import (
Model name: meta-llama/Meta-Llama-3.1-8B-Instruct
model is loaded from config
/home/kaiwu/work/llama-recipes/src/llama_recipes/model_checkpointing/checkpoint_handler.py:259: FutureWarning: `load_state_dict` is deprecated and will be removed in future versions. Please use `load` instead.
  dist_cp.load_state_dict(
/home/kaiwu/miniconda3/envs/llama/lib/python3.10/site-packages/torch/distributed/checkpoint/filesystem.py:657: FutureWarning: You are using `torch.load` with `weights_only=False` (the current default value), which uses the default pickle module implicitly. It is possible to construct malicious pickle data which will execute arbitrary code during unpickling (See https://github.com/pytorch/pytorch/blob/main/SECURITY.md#untrusted-models for more details). In a future release, the default value for `weights_only` will be flipped to `True`. This limits the functions that could be executed during unpickling. Arbitrary objects will no longer be allowed to be loaded via this mode unless they are explicitly allowlisted by the user via `torch.serialization.add_safe_globals`. We recommend you start setting `weights_only=True` for any use case where you don't have full control of the loaded file. Please open an issue on GitHub for any issues related to this experimental feature.
  torch.load(cast(IO[bytes], file_slice), map_location="cpu"),
Sharded state checkpoint loaded from ./fsdp_fine_tune_results/fsdp_model_finetuned_1_8/fine-tuned-meta-llama/Meta-Llama-3.1-8B-Instruct/
model is loaded from FSDP checkpoints
HuggingFace llama tokenizer has been saved in ./fsdp_fine_tune_results/fsdp_model_finetuned_1_8_hf
HuggingFace model checkpoints has been saved in ./fsdp_fine_tune_results/fsdp_model_finetuned_1_8_hf
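
For reference, the consolidated output directory is a regular Hugging Face checkpoint, so it can also be loaded directly with the standard transformers APIs. A minimal sketch, using the path produced by the conversion run above (device_map="auto" assumes accelerate is installed):

# Sketch: load the converted fine-tuned checkpoint with plain transformers APIs.
import torch
from transformers import AutoModelForCausalLM, AutoTokenizer

ckpt_dir = "./fsdp_fine_tune_results/fsdp_model_finetuned_1_8_hf"
tokenizer = AutoTokenizer.from_pretrained(ckpt_dir)
model = AutoModelForCausalLM.from_pretrained(ckpt_dir, torch_dtype=torch.bfloat16, device_map="auto")

inputs = tokenizer("I have tomatoes, basil and cheese at home. What can I cook for dinner?", return_tensors="pt").to(model.device)
outputs = model.generate(**inputs, max_new_tokens=128)
print(tokenizer.decode(outputs[0], skip_special_tokens=True))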


(llama) [[email protected] ~/work/llama-recipes (main)]$ python ./recipes/quickstart/inference/local_inference/inference.py --model_name ./fsdp_fine_tune_results/fsdp_model_finetuned_1_8_hf --prompt_file prompt_for_test.txt
/home/kaiwu/work/llama-recipes/src/llama_recipes/model_checkpointing/checkpoint_handler.py:17: DeprecationWarning: `torch.distributed._shard.checkpoint` will be deprecated, use `torch.distributed.checkpoint` instead
  from torch.distributed._shard.checkpoint import (
use_fast_kernelsFalse
Loading checkpoint shards: 100%|███████████████████████████████████████████████████████████████████████████████████████████████████| 7/7 [00:06<00:00,  1.08it/s]
User prompt deemed safe.
User prompt:
I have tomatoes, basil and cheese at home. What can I cook for dinner?\n

Setting `pad_token_id` to `eos_token_id`:None for open-end generation.
Starting from v4.46, the `logits` model output will have the same type as the model (except at train time, where it will always be FP32)
the inference time is 7803.556815721095 ms
User input and model output deemed safe.
Model output:
I have tomatoes, basil and cheese at home. What can I cook for dinner?\n
You can make a delicious margherita-style pizza. Preheat your oven to 350 degrees F and create the pizza base with your tomatoes. Top it with the creamy mozzarella and the fresh basil. Bake for 10 minutes and enjoy!

Before submitting

  • This PR fixes a typo or improves the docs (you can dismiss the other checks if that's the case).
  • Did you read the contributor guideline, Pull Request section?
  • Was this discussed/approved via a GitHub issue? Please add a link to it if that's the case.
  • Did you make sure to update the documentation with your changes?
  • Did you write any new necessary tests?

Thanks for contributing 🎉!

@mreso (Contributor) left a comment:

Looks good in general. Please check my comment and add a simple test to confirm correctness.
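
For reference, a rough sketch of the kind of unit test being requested here; the import path, class name, constructor signature, and data path below are assumptions and would need to be aligned with the actual dataset module:

# Hypothetical test sketch; InstructionDataset, its signature, and data_path are
# assumptions, not verified against the repository.
from transformers import AutoTokenizer

from llama_recipes.datasets.alpaca_dataset import InstructionDataset  # assumed path


def test_alpaca_dataset_uses_llama3_template_and_masks_prompt():
    tokenizer = AutoTokenizer.from_pretrained("meta-llama/Meta-Llama-3.1-8B-Instruct")

    class DatasetConfig:  # minimal stand-in for the real dataset_config
        data_path = "src/llama_recipes/datasets/alpaca_data.json"  # assumed location

    sample = InstructionDataset(DatasetConfig(), tokenizer, partition="train")[0]

    # Prompt tokens should be excluded from the loss (label -100),
    # while response tokens keep their real token ids.
    assert any(label == -100 for label in sample["labels"])
    assert any(label != -100 for label in sample["labels"])

    # The encoded example should contain the Llama 3 header special tokens.
    decoded = tokenizer.decode(sample["input_ids"])
    assert "<|start_header_id|>" in decoded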

Comment on lines +70 to +71

    else:
        last_idx = idx + 1

Don't we need to increment the last index in any case?

Suggested change:

-    else:
-        last_idx = idx + 1
+    last_idx = idx + 1
