Supported features #571
Comments
Thank you for the comments! (1) Fused attention is on by default for training! We use "splash attention" which is a custom and faster version! (And we're working on accelerated inference attentions.) ttconnect is super cool, thanks for sending! |
Thanks for the answer. Looking forward to the DPO support. It would of course be fantastic if HuggingFace datasets could be natively supported. I have never really been able to run large non-streaming datasets from HF on the TPUs (disk-size issues on the VMs), but we have been able to wrap the HF datasets in torch.split_dataset_by_node to stream on multiple TPUs. I'm not sure if I am able to implement something like this in MaxText, though, or on what level it should be implemented. Any chance you will support HF datasets in the future? Any way of preprocessing the data before it is split to the TPUs would be extremely useful for running experiments on dataset building, both for sampling and for filtering based on a field in the dataset. |
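For context, a minimal sketch of this kind of per-host streaming, assuming the `datasets` library's `split_dataset_by_node` helper (from `datasets.distributed`); the dataset name is only a placeholder:

```python
import jax
from datasets import load_dataset
from datasets.distributed import split_dataset_by_node

# Stream the dataset so nothing needs to fit on the VM's local disk.
ds = load_dataset("allenai/c4", "en", split="train", streaming=True)

# Give each TPU host its own disjoint slice of the stream.
ds = split_dataset_by_node(ds, rank=jax.process_index(), world_size=jax.process_count())

for example in ds:
    ...  # feed the per-host input pipeline
```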
Yes support for HF datasets in MaxText is on the way |
Thank you for tagging me on this. Yes, supporting HuggingFace dataset is in our plan. We have some implementations and are undergoing some perf evaluations to understand it better. I will update here when we have it out. |
Hi @peregilk , HuggingFace datasets are supported now. Please check out https://github.com/google/maxtext/blob/main/getting_started/Data_Input_Pipeline.md. |
Really fantastic! It makes things a lot more convenient. Especially reading JSON Lines from the buckets looks great. Do you support all the native HF formats, like jsonl.gz? |
Yes, jsonl.gz is supported, as well as other formats supported by datasets.load_dataset (https://huggingface.co/docs/datasets/en/loading) |
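For illustration, a minimal sketch of loading gzipped JSON Lines shards this way; the bucket path is hypothetical, and reading `gs://` paths assumes `gcsfs` is installed:

```python
from datasets import load_dataset

ds = load_dataset(
    "json",
    data_files="gs://my-bucket/corpus/shard-*.jsonl.gz",  # hypothetical path
    split="train",
    streaming=True,  # avoid materializing the whole corpus on local disk
)
```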
@aireenmei Is there more detailed documentation for this? I was, for instance, unable to figure out how to specify the validation set. |
Hi @peregilk, a specific validation set is not supported yet. But this is in our list of items to be worked on. |
@aireenmei Thanks a lot. Really looking forward to testing this. Since this seems very related, I am reporting it here, but I can open an issue if you prefer. I am training with:
There are 256 files in the directory. Close to the end of the first epoch, one of the workers throws this error in
|
Hi @peregilk , this should be the expected behavior. With the current implementation, you may not be able to use all the data in your train files. Say you have 256 files and you are using a v4-64, which has 8 hosts. Each host will read 256/8 = 32 shards. The i-th host reads the (8*x + i)-th shards (0 <= x < 32). For example, host 0 reads shards 0, 8, 16, ..., 248; host 7 reads shards 7, 15, ..., 255; etc. When a host finishes its current shard, it moves to the next shard assigned to it. But since each shard has a slightly different number of examples, training stops when one of the hosts runs out of data. In the example above, if host 0 is the first to finish its last shard, 248, it will look for shard 248 + 8 = 256, which is not available, and that results in the error you see. |
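A small sketch of the round-robin assignment described above, with 256 files and 8 hosts; the function name is illustrative, not MaxText code:

```python
num_files, num_hosts = 256, 8

def shards_for_host(host_id: int) -> list[int]:
    # Host i reads shards i, i + num_hosts, i + 2*num_hosts, ... (32 shards each here).
    return [num_hosts * x + host_id for x in range(num_files // num_hosts)]

print(shards_for_host(0))  # [0, 8, 16, ..., 248]
print(shards_for_host(7))  # [7, 15, ..., 255]
# When host 0 finishes shard 248 it asks for shard 248 + 8 = 256,
# which does not exist, hence the error above.
```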
Thanks @aireenmei. Not sure I understand, though. Wouldn't the logical behaviour here simply be to restart on the first shard that was given to the host when there are no more shards available? Otherwise you would have to duplicate your dataset to train for more than one epoch, right? |
I did not implement the auto restart because some users may not want their model to see repeated data. I can add multi-epoch support to our backlog. Meanwhile, it should be straightforward to change the shard update logic here: https://github.com/google/maxtext/blob/main/MaxText/input_pipeline/_input_pipeline_utils.py#L105 |
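One hedged way that shard update could wrap around for multi-epoch runs; this is a sketch, not the actual MaxText implementation, and the names are illustrative:

```python
def next_shard(current_shard: int, num_files: int, host_count: int) -> int:
    candidate = current_shard + host_count
    # Instead of stepping past the last file, wrap back to this host's first shard,
    # effectively starting a new epoch for that host.
    return candidate if candidate < num_files else current_shard % host_count
```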
OK. Makes sense. Thanks. |
@aireenmei I have tried using your validation support for HF datasets, and I am seeing the same issue here when setting hf_eval_files. Even if the number of shards is divisible by the number of workers, it still crashes asking for the next shard. I can't see any way to limit the number of eval steps so that it does not run out of shards. What am I missing? |
Hi @peregilk, indeed this is a bug. I will fix it. Meanwhile, this flag (https://github.com/google/maxtext/blob/main/MaxText/configs/base.yml#L336) controls the number of eval steps, and you can use it for now; I'll rename it to eval_steps in my next PR for clarity. |
any update on dpo? |
DPO is still underway! Very close now |
That is fantastic news @gobbleturk. Any chance this would also support SimPO? I understand the implementation would be very similar, and it would align perfectly with an ongoing project of mine ;) |
@rdyro is working on DPO implementation, perhaps we can look into SimPO as well |
Mainly wanted to start with thanking you for making MaxText available. I have been using it for a few days, and the first impression is fantastic. Getting started was really easy, it seems very stable, and the performance is excellent. It also seems to scale very nicely.
A few things I have not been able to figure out yet; it might be because of a lack of documentation, or simply because they are not implemented:
Is there any support for Flash attention, or any plans to implement it? This has been a major area where GPUs have been ahead of TPUs. I have noticed that there is now at least an experimental implementation from the JAX team: https://github.com/google/jax/blob/main/jax/experimental/pallas/ops/tpu/flash_attention.py.
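For reference, a hedged sketch of calling that experimental kernel directly; the shapes and keyword arguments follow my reading of the linked module and may differ between JAX versions:

```python
import jax
from jax.experimental.pallas.ops.tpu.flash_attention import flash_attention

batch, heads, seq_len, head_dim = 4, 16, 2048, 128
k0, k1, k2 = jax.random.split(jax.random.PRNGKey(0), 3)
q = jax.random.normal(k0, (batch, heads, seq_len, head_dim))
k = jax.random.normal(k1, (batch, heads, seq_len, head_dim))
v = jax.random.normal(k2, (batch, heads, seq_len, head_dim))

out = flash_attention(q, k, v, causal=True)  # fused attention kernel on TPU
```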
Training directly from tfds seemed straightforward. However, I was a bit confused about how to implement more advanced data loader features, for instance probability sampling as explained here. This can be somewhat tricky to do efficiently on multiple TPUs. What is the sensible approach here? Manually sampling into a tfds dataset does not seem very efficient. Are there external libraries that are compatible with MaxText?
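A hedged sketch of one way to do probability sampling in a tf.data/tfds pipeline, using tf.data.Dataset.sample_from_datasets; the dataset names and weights are placeholders, and how to wire this into MaxText's pipeline is exactly the open question:

```python
import tensorflow as tf
import tensorflow_datasets as tfds

# Map both corpora to the same element structure (plain text) so they can be mixed.
ds_a = tfds.load("lm1b", split="train").map(lambda x: x["text"])
ds_b = tfds.load("wikipedia/20220301.en", split="train").map(lambda x: x["text"])

mixed = tf.data.Dataset.sample_from_datasets(
    [ds_a, ds_b],
    weights=[0.7, 0.3],           # sampling probabilities
    seed=0,
    stop_on_empty_dataset=False,  # keep drawing from whatever remains
)
```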
Are there plans for implementing DPO/RLHF?
I also shamelessly wanted to point you to my own repo: https://github.com/peregilk/ttconnect. It is a very simple bash script that ideally should be run on a VM in the same zone. It automatically opens synchronised tmux windows to all the VMs in the pod and allows you to type the same command into all of them. This makes it even easier to go from a single TPU to pods.