converted mlperf gpt3 ckpt starts with a worse loss #887

Open
gramesh-amd opened this issue Sep 13, 2024 · 26 comments

gramesh-amd commented Sep 13, 2024

Hello,
We converted the paxml checkpoint and resumed training with the following config:

base_config: "base.yml"
tokenizer_path: "/dockerx/vocab/c4_en_301_5Mexp2_spm.model"
dataset_type: "tfds"
dataset_path: "/ckpts/c4_mlperf_dataset"
dataset_name: "en:3.0.4"
eval_dataset_name: "en:3.0.5"
split: "train2"
tokenize_eval_data: False
eval_data_column: "ids"
add_bos: False
add_eos: False
eval_split: "validation_tokenized_5662seqs"
eval_interval: 10  # number of training steps between evaluations
target_eval_loss: 2.69  # early stop once the target eval loss is reached

enable_checkpointing: True
save_interval_steps: 5

# Args coming from the NVIDIA spreadsheet http://shortn/_W9CzVbtQde and
# third_party/py/maxtext/configs/a3/llama_2_7b.
hardware: "gpu"
steps: 10
model_name: "gpt3-175b" # this model config is unchanged
attention: "cudnn_flash_te"

gradient_accumulation_steps: 1

dcn_data_parallelism: 1
dcn_fsdp_parallelism: -1
dcn_pipeline_parallelism: 1
dcn_tensor_parallelism: 1
dcn_sequence_parallelism: 1
ici_fsdp_parallelism: 8
ici_data_parallelism: 1
ici_sequence_parallelism: 1
ici_tensor_parallelism: 1
ici_pipeline_parallelism: 1
per_device_batch_size: 5
max_target_length: 2048

remat_policy: "full"
use_iota_embed: True
scan_layers: False
async_checkpointing: False
logits_dot_in_fp32: False
megablox: False

dtype: "bfloat16"
quantization: ""
quantize_kvcache: False
kv_quant_axis: "heads_and_dkv"
kv_quant_dtype: "int8"
weight_dtype: bfloat16
checkpoint_is_quantized: False # Set to True if reading from a saved aqt quantized checkpoint

skip_first_n_steps_for_profiler: 3

mesh_axes: ['data', 'stage', 'fsdp', 'fsdp_transpose', 'sequence', 'tensor', 'autoregressive']
logical_axis_rules: [
                      ['activation_batch', ['data', 'fsdp', 'fsdp_transpose',]],
                       # For pipeline parallelism the pre and post decoder layer tensors' batch dimension is sharded by stages.
                       # Microbatches are sharded by stage, so moving out of and into this sharding should be a local reshape.
                       # The "stage" needs to be listed first since the microbatch dimension is first before the reshape.
                      ['activation_embed_and_logits_batch', ['stage', 'data', 'fsdp', 'fsdp_transpose']],
                      ['activation_heads', ['tensor','sequence']],
                      ['activation_kv_heads', ['tensor','sequence']],
                      ['activation_length', 'sequence'],
                      ['activation_embed', 'tensor'],
                      ['activation_mlp', 'tensor'],
                      ['activation_kv', 'tensor'],
                      ['activation_kv_batch', ['data', 'fsdp', 'fsdp_transpose',]],
                      ['activation_kv_head_dim', 'tensor'],
                      ['activation_vocab', ['tensor', 'sequence']],
                      ['activation_vocab', 'tensor'],
                      ['activation_vocab', 'sequence'],
                      ['activation_stage','stage'],
                      ['mlp', ['fsdp_transpose', 'tensor', 'autoregressive']],
                      ['vocab', ['tensor', 'autoregressive']],
                      ['embed', ['fsdp', 'fsdp_transpose', 'sequence']],
                      ['embed', ['fsdp', 'sequence']],
                      ['norm', 'fsdp'],
                      ['heads', ['tensor', 'autoregressive']],
                      ['layers', 'stage'],
                      ['kv', []],
                      ['kv_heads', ['tensor', 'autoregressive']],
                      ['kv_head_dim', []],
                      ['cache_batch', []],
                      ['cache_heads', ['autoregressive', 'tensor']],
                      ['cache_kv', []],
                      ['cache_sequence', []],
                    ]

# Axes used for DCN must be earlier in this list than ICI, see (b/339009148) for details
data_sharding: [['data', 'stage', 'fsdp', 'fsdp_transpose', 'sequence', 'tensor', 'autoregressive']]

The tokenizer and data splits (3.0.4, 3.0.5) were downloaded from the mlperf2 bucket. I have also tried using the c4_mlperf dataset_type like this:

base_config: "base.yml"
tokenizer_path: "/dockerx/vocab/c4_en_301_5Mexp2_spm.model"
dataset_type: "c4_mlperf"
dataset_path: "/ckpts/c4_mlperf_dataset"
dataset_name: "en:3.0.4"
eval_dataset_name: "en:3.0.5"
split: "train2"
eval_split: "validation_tokenized_5662seqs"

python maxtext/MaxText/train.py /dockerx/maxtext/MaxText/configs/gpt3_175b_gpu.yml base_output_directory=/ckpts/paxml/gpt3-conversion run_name=gpt3-conversion steps=4010 scan_layers=true

^ scan_layers is set to true, in line with how we converted the ckpt

completed step: 4000, seconds: 91.772, TFLOP/s/device: 24.021, Tokens/s/device: 22.316, total_weights: 65504, loss: 7.644, perplexity: 2088.295
To see full metrics 'tensorboard --logdir=/ckpts/paxml/gpt3-conversion/gpt3-conversion/tensorboard/'
completed step: 4001, seconds: 12.945, TFLOP/s/device: 170.297, Tokens/s/device: 158.213, total_weights: 65504, loss: 7.687, perplexity: 2179.917
completed step: 4002, seconds: 11.886, TFLOP/s/device: 185.471, Tokens/s/device: 172.310, total_weights: 65504, loss: 7.739, perplexity: 2297.215
completed step: 4003, seconds: 11.885, TFLOP/s/device: 185.479, Tokens/s/device: 172.318, total_weights: 65504, loss: 7.597, perplexity: 1992.680
completed step: 4004, seconds: 11.931, TFLOP/s/device: 184.759, Tokens/s/device: 171.649, total_weights: 65504, loss: 7.680, perplexity: 2165.097
completed step: 4005, seconds: 11.913, TFLOP/s/device: 185.043, Tokens/s/device: 171.912, total_weights: 65504, loss: 7.663, perplexity: 2128.778
completed step: 4006, seconds: 11.945, TFLOP/s/device: 184.546, Tokens/s/device: 171.451, total_weights: 65504, loss: 7.582, perplexity: 1963.248
completed step: 4007, seconds: 11.913, TFLOP/s/device: 185.048, Tokens/s/device: 171.918, total_weights: 65504, loss: 7.648, perplexity: 2096.574
completed step: 4008, seconds: 12.013, TFLOP/s/device: 183.498, Tokens/s/device: 170.478, total_weights: 65504, loss: 7.524, perplexity: 1851.645
completed step: 4009, seconds: 11.920, TFLOP/s/device: 184.929, Tokens/s/device: 171.807, total_weights: 65504, loss: 7.618, perplexity: 2034.629

^ training starts with a very high loss; we expected something closer to 2.77

We have verified from the logs that training loads the right checkpoint, the correct data splits, and the correct tokenizer.

@gramesh-amd

@ZhiyuLi-goog thanks again for your help with the other issues. Do you see any problems with the config, or do you know why the loss is so much higher?

@ZhiyuLi-goog

I have never tried this on GPU.
To narrow down the root cause, could you try with normal attention?

attention: "dot_product" 

@gramesh-amd

With attention: "dot_product":
completed step: 4000, seconds: 91.772, TFLOP/s/device: 24.021, Tokens/s/device: 22.316, total_weights: 65504, loss: 7.644, perplexity: 2088.295
To see full metrics 'tensorboard --logdir=/ckpts/paxml/gpt3-conversion/gpt3-conversion/tensorboard/'
completed step: 4001, seconds: 39.677, TFLOP/s/device: 277.795, Tokens/s/device: 258.083, total_weights: 327520, loss: 7.638, perplexity: 2076.376
completed step: 4002, seconds: 39.883, TFLOP/s/device: 276.359, Tokens/s/device: 256.749, total_weights: 327520, loss: 7.646, perplexity: 2092.290

I get a similar loss as before.

@ZhiyuLi-goog
ZhiyuLi-goog commented Sep 13, 2024

Oh, could you try something like

python3 MaxText/train.py MaxText/configs/base.yml run_name="${RUNNAME}" model_name=gpt3-175b

instead of changing the base.yml?
You can find the exact model YAML setup in gpt3-175b.yml, and there is some additional setup for gpt3-175b:

# these flags might be relevant to output results
logits_via_embedding: True
normalize_embedding_logits: False
logits_dot_in_fp32: False
normalization_layer_epsilon: 1.e-05
use_iota_embed: True
opt_type: "adam_pax"

I think logits_via_embedding: True should be the most important one.
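
For context on what that flag changes, here is a minimal JAX sketch of tied vs. untied output projections (the names are illustrative only, not MaxText's actual module code): with weight tying, the logits reuse the token-embedding table, so the converted checkpoint does not need a separate output head.

import jax.numpy as jnp

def logits_tied(hidden, embedding_table):
    # logits_via_embedding-style: project the decoder output back through the
    # transpose of the [vocab, emb] token-embedding table (weight tying).
    # hidden: [batch, seq, emb]
    return jnp.einsum("bse,ve->bsv", hidden, embedding_table)

def logits_untied(hidden, output_proj):
    # Separate output head: an independent [emb, vocab] matrix that would have
    # to be present (and correctly converted) in the checkpoint.
    return jnp.einsum("bse,ev->bsv", hidden, output_proj)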

@gramesh-amd

I tested these out, first running

python3 MaxText/train.py MaxText/configs/base.yml run_name="${RUNNAME}" model_name=gpt3-175b

and then adding the other relevant flags you posted, one by one. All of them start with the bad loss (7.6x). So it's not flash attention, not the tokenizer (the validation set is pre-tokenized and the eval loss is also bad), and not the config args (since I tried the flags you suggested).

It's probably something to do with the model weights.

@ZhiyuLi-goog

I can take a look at the full logs if you have them.
The final effective config should be in that log.

@gramesh-amd

Thanks. Here are the logs:

maxtext_gpt3_logs.txt

@ZhiyuLi-goog

I checked the log.
All updated parameters matched, and I didn't find anything suspicious.

@gramesh-amd

Thanks for checking.
Yeah, it's strange that it starts with a bad loss. I also tested the tokenizer, and it seems fine.

@ZhiyuLi-goog
ZhiyuLi-goog commented Sep 17, 2024

The only one I found that looks odd is:

+ Config param weight_dtype: float32
- Config param weight_dtype: bfloat16

Could you try weight_dtype: float32 instead of bfloat16?
Activations are computed in bfloat16, while all parameters and optimizer state should be kept in float32 for better convergence.

That said, I would not expect this to cause such a big gap.
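
To illustrate that split (a minimal sketch of a single dense layer, not MaxText's actual code path): the float32 master weights are cast to bfloat16 only for the forward compute, while the optimizer keeps updating the float32 copy.

import jax.numpy as jnp

def apply_dense(params_f32, x_bf16):
    # dtype: bfloat16 for activations, weight_dtype: float32 for the master copy.
    w_bf16 = params_f32["w"].astype(jnp.bfloat16)  # cast just for the matmul
    return x_bf16 @ w_bf16

params = {"w": jnp.ones((8, 8), dtype=jnp.float32)}   # float32 master weights
x = jnp.ones((2, 8), dtype=jnp.bfloat16)              # bfloat16 activations
y = apply_dense(params, x)                            # compute runs in bfloat16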

@gramesh-amd
gramesh-amd commented Sep 17, 2024

I tried weight_dtype: float32 as well. Same problem.

I'm wondering if we could send you our converted ckpt for you to load, to verify whether it's a ckpt problem?

@ZhiyuLi-goog

I can give it a try on the TPU side.

By the way, would it be useful to you to print the mean of each param state after conversion?
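
If it helps, such a check could look roughly like the sketch below (assuming params is the restored parameter pytree on the MaxText side, e.g. state.params after loading the converted checkpoint; this is not part of the conversion script itself). Comparing the same dump on the pax side would show which tensors diverge.

import jax
import jax.numpy as jnp

def log_param_stats(params):
    # Print shape/mean/std/abs-max for every leaf of the parameter pytree so
    # the pax-side and converted MaxText-side checkpoints can be diffed.
    leaves = jax.tree_util.tree_flatten_with_path(params)[0]
    for path, leaf in leaves:
        leaf = jnp.asarray(leaf, dtype=jnp.float32)
        name = jax.tree_util.keystr(path)
        print(f"{name}: shape={leaf.shape} mean={leaf.mean():.6e} "
              f"std={leaf.std():.6e} absmax={jnp.abs(leaf).max():.6e}")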

@gramesh-amd
gramesh-amd commented Sep 20, 2024

I'm not sure if it would be useful.
We also loaded the pax ckpt directly in paxml, and there it starts at the right loss. So at this point, we suspect something is going wrong during conversion.

@ZhiyuLi-goog

It would be easiest if you have a converted ckpt; then I can directly compare your converted ckpt against ours.
If you have output logs from the conversion script, I can take a look at those as well.

We didn't try the conversion on GPU, so I guess something might behave differently there.

@gramesh-amd

Great.
We will share the converted ckpt and the conversion logs. Do you have a gcloud bucket that I could push them to, or do you recommend some other way?

@ZhiyuLi-goog

we will share the converted ckpt and the conversion logs. Do you have a gcloud bucket that i could push it to? or do you recommend some other way?

It would be great if you can share an open gcloud bucket with us.
By the way, which conversion script are you using? The one in the mlperf 4.0 submission or the one in the maxtext main branch?

@gramesh-amd

OK, let me do that.
We tried both versions, and with both we are getting the same problem.

@ZhiyuLi-goog

We tried both versions and with both, we are getting the same problem

Gotcha, thank you for the info!

ZhiyuLi-goog self-assigned this Sep 24, 2024
@gramesh-amd

We have created the bucket and will share access with you soon (I got your Google email from one of your commits).

@gramesh-amd
gramesh-amd commented Oct 7, 2024

Hello again,
You should finally have access to the bucket containing all the ckpts

We have shared three ckpts:
1. gpt3-conversion-forked/: created with the mlperf fork
2. gpt3-conversion-noscan-cache/: created with the latest branch, scan_layers=false
3. gpt3-conversion/: created with the latest branch, scan_layers=true

Let us know if you are able to access them and if you have any questions.
Thanks

@ZhiyuLi-goog

Thank you @gramesh-amd

I will test with your ckpt.

@gabeweisz

I wrote a script to look at the checkpoint that we generated and compare it to the original data.
What I found (at least on my initial look) is that each row has the first 1/4 of its data (starting at element 0) populated, and the rest of it is 0.
This is interesting because we used 4 nodes with 8 GPUs each, with intra- and inter-node FSDP.
This makes me think that we have data from the first node and 0's from the other three nodes.

Any suggestions for how we can confirm that this is what happened and debug it?
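
One way to confirm it directly (a rough sketch; it assumes one parameter array has already been loaded into host memory as numpy, and that the suspected sharding is over the last axis across the 4 nodes) is to measure the fraction of exactly-zero entries in each quarter of the array:

import numpy as np

def zero_fraction_by_slice(arr, num_shards=4):
    # Split the last axis into num_shards equal slices (one per suspected FSDP
    # host) and report how much of each slice is exactly zero. If only slice 0
    # is non-zero, only the first node's shard made it into the checkpoint.
    arr = np.asarray(arr)
    width = arr.shape[-1] // num_shards
    for i in range(num_shards):
        chunk = arr[..., i * width:(i + 1) * width]
        print(f"slice {i}: zero fraction = {float(np.mean(chunk == 0)):.4f}")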

@ZhiyuLi-goog

@gabeweisz
Great finding and thank you for looking into it!

I am just wondering how you ran the script? I would expect:

  1. the script to run on every device (a no-brainer on TPU; I am wondering if some tweak is needed on GPU), and
  2. the output_directory to be accessible to every device.

Key idea:

# each process/device distributes the weight (fsdp) and keeps only its own shard
result = jax.make_array_from_single_device_arrays(
    shape,
    sharding,
    [jax.device_put(np.array(arr[index]), d) for d, index in sharding.addressable_devices_indices_map(shape).items()],
)
...
...

# distribute saving to a output directory like a gcs bucket which can be accessible by all devices.
if save_checkpoint(checkpoint_manager, converted_state.step, converted_state):
  max_logging.log(f"saved a checkpoint at step {converted_state.step}")
# Upon preemption, exit when and only when all ongoing saves are complete.
if checkpoint_manager.reached_preemption(converted_state.step):
  checkpoint_manager.wait_until_finished()
  sys.exit()

Example:

I tried the script yesterday and it worked for me.

# checkpoint loading from a fixed folder
RUN_NAME=ckpt
BASE_OUTPUT_DIR=gs://path/to/output

python MaxText/convert_gpt3_ckpt_from_paxml.py \
  --paxml-ckpt-path=gs://mlperf-llm-public2/gpt3_spmd1x64x24_tpuv4-3072_v84_20221101/checkpoints/checkpoint_00004000 \
  --maxtext-model-name=gpt3-175b \
  --run-name=$RUN_NAME \
  --base-output-directory=$BASE_OUTPUT_DIR

Note that the output directory is a GCS bucket, which is accessible by all devices.

@gabeweisz

We ran the script in a way very similar to how you ran it - my colleague Gowtham has shared what we did earlier.

When we ran this, we didn't have a shared NFS big enough for all the nodes and did not have access to a GCS bucket - each node was writing to its own local directory.

I did check, and once the script was finished, only node 0 had a checkpoint - none of the others did.

Do you think this caused the issue? If so, does Orbax have a way to work around this?

Another option is that I can try to modify an on-disk checkpoint using tensorstore as in the documentation - there is no real reason why we need to load the checkpoint onto GPUs to convert it from one format to another.
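
For illustration, reading one parameter's array straight off disk with tensorstore could look roughly like this (the zarr driver and the path are assumptions about how the checkpoint is laid out on disk; the real structure may differ):

import tensorstore as ts

# Hypothetical path to one parameter's array directory inside the checkpoint;
# the zarr driver is an assumption about the on-disk format.
spec = {
    "driver": "zarr",
    "kvstore": {"driver": "file", "path": "/path/to/checkpoint/some_param"},
}
store = ts.open(spec).result()            # open lazily, nothing is read yet
first_rows = store[0:4].read().result()   # read a small slice into numpy
print(first_rows.shape, first_rows.dtype)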

@gabeweisz

We just found the place in the documentation where Orbax says that all nodes need to write to the same filesystem - that explains what went wrong for us.

@ZhiyuLi-goog

When we ran this, we didn't have a shared NFS big enough for all the nodes and did not have access to a GCS bucket - each node was writing to its own local directory.
We just found the place in the documentation where orbax says that all nodes need to write to the same filesystem - that explains what went wrong for us.

Exactly, I think that is the root cause.
