Commit
This commit does not belong to any branch on this repository, and may belong to a fork outside of the repository.
* Move tpu end-to-end test scripts to tpu folder * unify WORKDIR to /deps * Share GCS path between Gemma-7b tests * Add README for llama2-7B * adding script to fix the style and adding modified/fixed files with line length 125 * Move apt install from `rto_setup.sh` to `setup.sh` * Update instructions for installing snap. * Removes batch size from prefill attention calculation. * Fixes for inf testing. * Revert "Fixes for inf testing." This reverts commit b15b1d5. * Fixes * Fix subset of hosts dataloading * inference microbenchmark - allow run specified stages - allow run specific prefill length(s) - delete prefill result - printout prefill result added funcs in max_utils * Update Run_MaxText_via_xpk.md Fixing typo. * inference_microbenchmark: - time prefill only - benchmark prefill and insert * Mark nvidia devtools repo as trusted This is a stopgaps measure to circumvent the nvidia repo's gpg signature issue * Explicitly set AQT Freezer mode in MaxText. PiperOrigin-RevId: 627250589 * Move aqtp pin up * Pre-commit config * Update 128B config on v5e to use qkv_proj_offloaded remat_policy * [MaxText] Rename llama2_7b_single_host_gpu.yml to make it clear that it can be used for any number of host. PiperOrigin-RevId: 627804089 * Split Mixtral test into two scripts * Update jax.tree_map to jax.tree_util.tree_map * change norm sharding fix lint Revert "fix lint" This reverts commit d8dc450. fix lint * Change l2norm to use jnp.sqrt * Fix test_tokenize * Streamlined setup.sh to have fewer apt install calls * loosen tolerance in assert_params_sufficiently_sharded * Enable entropy on multihost CPUs. * Add tests to GPU runner * Replace deprecated np.product with np.prod * fix norm sharding * Add Llama2-70b test * Internal change only. PiperOrigin-RevId: 630446330 * Add more tests for Mixtral * Make some AQT dataclasses to use keyword-only fields (1/N) This cl introduces an temporary decorator that will be temporarily used during this migration. The eventual goal is to enforce kw_only=True in all dataclasses unless it's not feasible, aiming to make AQT less error-prune and improve readability. PiperOrigin-RevId: 631132072 * Reverts e8b53e5 PiperOrigin-RevId: 631465526 * Update tflops calculation * fix sharding on generate cache in prefill results. * Remove async XLA_FLAGS from A3 configs. XLA PR openxla/xla#11422 removed some XLA flags relating to async collectives. This caused the A3 configs to fail to run, so this change removes such flags from the A3 configs. The flags removed are: --xla_gpu_enable_async_all_gather=true --xla_gpu_enable_async_reduce_scatter=true --xla_gpu_enable_async_all_reduce=true Such flags had no impact before the XLA PR as the async collectives were already enabled by default. * Update llama2_7b_gpu.yml PiperOrigin-RevId: 631752008 * Add forward pass logit check test for Llama2-7b * Eval the command string from XPK for GPU script * Remove cases where the deprecated --xla_gpu_simplify_all_fp_conversions is set to its default value. PiperOrigin-RevId: 633645462 * streamline CI test structure * fix pylint fix pylint: Using variable 'p_eval_step' before assignment (#651) * Remove async XLA_FLAGS from A3 configs * Add llama-70b gpu config. PiperOrigin-RevId: 634267313 * Support data input from HuggingFace * Update the NCCL flags for A3+. * add gemma logit test * Integrate orbax logger in maxtext for structured logging. Integrate orbax logger in maxtext for structured logging. Integrate orbax logger in maxtext for structured logging. Integrate orbax logger in maxtext for structured logging. Integrate orbax logger in maxtext for structured logging. Integrate orbax logger in maxtext for structured logging. * fix hf input pipeline * Fix prefill assertion * Remove decode asserts from Gemma test files * add single controller flag * fix OOM issue running inference microbenchmark with llama13b on v5e4 * Add Llama2 13B Tests * Don't clip fp8 stats * Integrate nsys profiler Remove 'enable_profiler' config and add 'profiler' config instead * Add MoE matmul implementation * fix OUTPUT_PATH in v5e/128b.sh * squash * Update flops calculation to active experts in moe * Enable kv cache layout control * Fix Gemma Readme link * Internal change only. Reverts a28f518 PiperOrigin-RevId: 639890999 * Upgrade Pinned Base Image for GPU * Metrics bug: server_lib should be config_lib * Fix MoE matmul scale issue * Removed unused Pallas import from layers/attentions.py PiperOrigin-RevId: 640481280 * Change norm sharding for llama2-7b to fsdp. PiperOrigin-RevId: 640498890 * Copybara import of the project: -- d7d694f by RissyRan <[email protected]>: Fix forward test for Mixtral COPYBARA_INTEGRATE_REVIEW=#679 from google:ranran_fix_forward_test d7d694f PiperOrigin-RevId: 640537456 * Set additional flags for a3 and a3plus * Use run_id instead of sha for docker tag * refactor data input pipeline and add perf data * Add gpt3 175b on v5e config * Pipeline parallelism support (linear only) * Turn on layer scanning for llama2-7b on GPU. This better utilizes recent optimizations to collective approximation in the XLA latency hiding scheduler. PiperOrigin-RevId: 642559284 * reshape q * Add profiler flags to JetStream server Add jetstream config backward compatible * fix tfds instruction * Add vanilla megablox to MoE * Add llama2 70b training config for v5e * base.yml changes circular changes to pipeline.py pyconfig circ changes pipeline parallel tests circular style tree map, half passed tests Total iterations circularized improved iteration comment run all tests test both circular and non-circular circ storage comment circ storage pushing index comment * Account for new mesh axes for llama2-7b, and llama2-70b on GPUs. PiperOrigin-RevId: 643999933 * Sharding the llama2 70b on v5e-16 more efficiently. https://arxiv.org/pdf/2211.05102 https://arxiv.org/pdf/1909.08053 * add compute_axis_order * Add maxengine_server configs to base.yml * Add FSDP + Megablox * Llama3-8b model config * MaxText package * fix data loading from HF hub * Fix llama2-{7,70}b sharding on GPU. PiperOrigin-RevId: 645365795 * Move stage to second axis in mesh Move stage to second axis in mesh * Copybara import of the project: -- 1718b89 by RissyRan <[email protected]>: Refactor permute and unpermute operations COPYBARA_INTEGRATE_REVIEW=#714 from google:refactor_mega b101cbc PiperOrigin-RevId: 645591567 * Fix Mesh setup for multiprocess CPUs. * add kv_quant_axis * Add a directory check for the . If it fails, attempt to check a path relative to the base config, similar to what is done for model configurations. Minor update Remove the raised exception * Add mistral tokenizer to maxtext/assets * Update the dependencies to prepare for integration of emergency checkpointing Withhold some package versions Update version of typing_extensions * Make broadcasting from one replica to all more memory efficient PiperOrigin-RevId: 646526020 * Inference Microbenchmark Sweep * Fix mesh_axes and data_sharding for LLaMA 2 GPU configs. PiperOrigin-RevId: 646795068 * Allow owners to have any approver Fix AddLabel syntax Fix punctuation * Enable saving using Orbax's emergency checkpoint manager fix data loading from HF hub Add explanation to the emergency checkpoint feature Fix pylint issues Minor changes to the config file resolve conflicts Inference Microbenchmark Sweep Fix mesh_axes and data_sharding for LLaMA 2 GPU configs. PiperOrigin-RevId: 646795068 * Add Llama2 7B, 13B high performance training configs * Load/Save Aqt quantized checkpoint. * modify prefill to return first token * Fix and protect simple_layer Fix and protect simple_layer Fix and protect simple_layer Fix and protect simple_layer * Adding option for int4 quantization to kvcache. * support eval dataset and refactor * Support partial overrides for logical_axis_rules. * Fix simple test step count * Clean up MoE brute force implementation * Preliminary restore with lots of hardcoding and hacking Refactor the code and remove the hardcoding More refactoring Cleanup for pull request Address linting issues Preliminary restore with lots of hardcoding and hacking Refactor the code and remove the hardcoding More refactoring Cleanup for pull request Address linting issues Small formatting Fix merging issues * Add convergence tests on A3 GPU * Update tile size * Handle cases where memstats are not available for the device. Memstats are not guaranteed to be available and can throw an error or return None. This change will handle both `jaxlib.xla_extension.XlaRuntimeError` if the device is not a PjRt addressable device or `KeyError` if the memstats returns None if they are not available. * Fix validation error for other models * Fix decode.py to also use first_token from prefill_call * Add moe perf number * move num_experts pyconfig assertion to fix tests * Cast type for inputs before kernel call * Move sharding overrides to models/ directory. PiperOrigin-RevId: 650994392 * Enable quantization for MoE Matmul implementation * Integrate and test Goodput Monitor with MaxText * Adding Tokens/s/device to the log. * Adding support for mixed precision quantization configs. --------- Co-authored-by: maxtext authors <[email protected]> Co-authored-by: Nina Cai <[email protected]> Co-authored-by: NinaCai <[email protected]> Co-authored-by: michelle-yooh <[email protected]> Co-authored-by: In-Ho Yi <[email protected]> Co-authored-by: A9isha <[email protected]> Co-authored-by: In-Ho Yi <[email protected]> Co-authored-by: ssusie <[email protected]> Co-authored-by: tonyjohnchen <[email protected]> Co-authored-by: Roshani Narasimhan <[email protected]> Co-authored-by: Pate Motter <[email protected]> Co-authored-by: khatwanimohit <[email protected]> Co-authored-by: Morgan Du <[email protected]> Co-authored-by: DongHyun Choi <[email protected]> Co-authored-by: gobbleturk <[email protected]> Co-authored-by: Raymond Zou <[email protected]> Co-authored-by: Bixia Zheng <[email protected]> Co-authored-by: Ran Ran <[email protected]> Co-authored-by: Zhiyu Li <[email protected]> Co-authored-by: Rafi Witten <[email protected]> Co-authored-by: RissyRan <[email protected]> Co-authored-by: Greg Olechwierowicz <[email protected]> Co-authored-by: Junwei Yang <[email protected]> Co-authored-by: Reed Wanderman-Milne <[email protected]> Co-authored-by: Dimitar (Mitko) Asenov <[email protected]> Co-authored-by: aireenmei <[email protected]> Co-authored-by: yangyuwei <[email protected]> Co-authored-by: Abhinav Singh <[email protected]> Co-authored-by: Sadi Kneipp <[email protected]> Co-authored-by: jwyang-google <[email protected]> Co-authored-by: Anfal Siddiqui <[email protected]> Co-authored-by: Brendan Slabe <[email protected]> Co-authored-by: Sergei Lebedev <[email protected]> Co-authored-by: Jon Bolin <[email protected]> Co-authored-by: Zijun Zhou <[email protected]> Co-authored-by: Zhihao Shan <[email protected]> Co-authored-by: Adam O'Brien <[email protected]> Co-authored-by: Vipan Nalla <[email protected]> Co-authored-by: Vipan Nalla <[email protected]> Co-authored-by: Xuefeng Gu <[email protected]> Co-authored-by: Andy Ye <[email protected]> Co-authored-by: Mitali Singh <[email protected]> Co-authored-by: xuefgu <[email protected]> Co-authored-by: Luke Baumann <[email protected]> Co-authored-by: Dipannita Shaw <[email protected]>
- Loading branch information