bq draft #1250

Open · wants to merge 20 commits into main
15 changes: 11 additions & 4 deletions .github/workflows/RunTests.yml
@@ -101,7 +101,6 @@ jobs:
tf_force_gpu_allow_growth: true
container_resource_option: "--shm-size 2g --runtime=nvidia --gpus all --privileged"


clean_up:
if: ${{ always() }} # always execute, regardless of previous jobs or steps.
needs: [gpu_unit_tests, gpu_integration_tests, tpu_unit_tests, tpu_integration_tests]
@@ -115,9 +114,17 @@ jobs:
run: gcloud container images delete gcr.io/tpu-prod-env-multipod/maxtext_${{ github.run_id }}:gpu --force-delete-tags --quiet
- name: Delete TPU image
run: gcloud container images delete gcr.io/tpu-prod-env-multipod/maxtext_${{ github.run_id }}:tpu --force-delete-tags --quiet
- name: Notify failed build # creates an issue or modifies last open existing issue for failed build

notify:
name: Notify failed build # creates an issue or modifies last open existing issue for failed build
needs: [gpu_unit_tests, gpu_integration_tests, tpu_unit_tests, tpu_integration_tests]
runs-on: ["self-hosted"]
steps:
- name: Check whether one of the jobs failed
if: ${{ failure() && github.event.pull_request == null }}
uses: jayqi/failed-build-issue-action@1a893bbf43ef1c2a8705e2b115cd4f0fe3c5649b # v1.2.0
if: failure() && github.event.pull_request == null
with:
github-token: ${{ secrets.GITHUB_TOKEN }}

- name: Log message if dependent job succeeded
if: ${{ ! (failure() && github.event.pull_request == null) }}
run: echo "Conditions for creating/updating issue not met. Skipping."
2 changes: 1 addition & 1 deletion .github/workflows/UploadDockerImages.yml
@@ -65,7 +65,7 @@ jobs:
bash .github/workflows/build_and_upload_images.sh CLOUD_IMAGE_NAME=maxtext_gpu_jax_pinned MODE=pinned DEVICE=gpu PROJECT=tpu-prod-env-multipod LOCAL_IMAGE_NAME=maxtext_gpu_local_jax_pinned
- name: build jax stable stack image
run : |
bash .github/workflows/build_and_upload_images.sh CLOUD_IMAGE_NAME=maxtext_gpu_jax_stable_stack MODE=stable_stack DEVICE=gpu PROJECT=tpu-prod-env-multipod LOCAL_IMAGE_NAME=maxtext_gpu_jax_stable_stack BASEIMAGE=us-docker.pkg.dev/tpu-prod-env-multipod/jax-stable-stack/gpu:latest MAXTEXT_REQUIREMENTS_FILE=requirements_with_jax_stable_stack.txt
bash .github/workflows/build_and_upload_images.sh CLOUD_IMAGE_NAME=maxtext_gpu_jax_stable_stack MODE=stable_stack DEVICE=gpu PROJECT=tpu-prod-env-multipod LOCAL_IMAGE_NAME=maxtext_gpu_jax_stable_stack BASEIMAGE=us-central1-docker.pkg.dev/deeplearning-images/jax-stable-stack/gpu:latest MAXTEXT_REQUIREMENTS_FILE=requirements_with_jax_stable_stack.txt
- name: build image with stable stack nightly jax
run: |
bash .github/workflows/build_and_upload_images.sh CLOUD_IMAGE_NAME=maxtext_gpu_stable_stack_nightly_jax MODE=stable_stack DEVICE=gpu PROJECT=tpu-prod-env-multipod LOCAL_IMAGE_NAME=maxtext_gpu_jax_stable_stack_nightly BASEIMAGE=us-docker.pkg.dev/tpu-prod-env-multipod/jax-stable-stack/gpu/jax_nightly:latest MAXTEXT_REQUIREMENTS_FILE=requirements_with_jax_stable_stack.txt
10 changes: 10 additions & 0 deletions MaxText/configs/base.yml
@@ -120,6 +120,8 @@ logits_via_embedding: False
normalize_embedding_logits: True # whether to normalize pre-softmax logits if logits_via_embedding is true
logits_dot_in_fp32: False # whether to use fp32 in logits_dense or shared_embedding dot product for stability
cast_logits_to_fp32: True # whether to cast the logits to fp32. The higher precision is generally beneficial, but it can vary slightly.
float32_qk_product: False # in dot_product attention, whether to cast the inputs of the qk product to fp32
float32_logits: False # in dot_product attention, whether to cast the inputs of the softmax to fp32

# mixture of experts (moe)
num_experts: 1
@@ -129,6 +131,14 @@ sparse_matmul: True
capacity_factor: -1.0 # a factor to decide expert capacity for token dropping, and no dropping by default
load_balance_loss_weight: 0.01 # weight for the load balance loss

# deepseek moe
base_moe_mlp_dim: 7168 # intermediate dimension of the MoE layer (use base_mlp_dim if not DeepSeek-style)
first_num_dense_layers: 0 # number of initial dense layers in the model
shared_experts: 1 # number of shared experts (DeepSeek-style MoE)
routed_scaling_factor: 1.0 # scaling factor for routing scores
routed_score_func: "" # scoring function for routing
routed_bias: False # whether a bias term is added for routing

# pipeline parallelism
# The number of decoder layers is equal to the product of num_stages, num_layers_per_pipeline_stage and num_pipeline_repeats.
# There is a tradeoff between the num_layers_per_pipeline_stage and num_pipeline_repeats: The more layers per stage the easier
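Note: a minimal, illustrative sketch of how the two new attention flags added to base.yml above (`float32_qk_product`, `float32_logits`) could gate fp32 casts inside dot-product attention. This is not the MaxText implementation; the function and argument names are assumptions.

```python
import jax
import jax.numpy as jnp

def dot_product_attention(q, k, v, float32_qk_product=False, float32_logits=False):
  # Hypothetical attention helper; q/k/v are [..., length, heads, head_dim].
  if float32_qk_product:
    # Cast the inputs of the qk product to fp32 (float32_qk_product: True).
    q, k = q.astype(jnp.float32), k.astype(jnp.float32)
  logits = jnp.einsum("...qhd,...khd->...hqk", q, k) / jnp.sqrt(q.shape[-1])
  if float32_logits:
    # Cast the inputs of the softmax to fp32 (float32_logits: True).
    logits = logits.astype(jnp.float32)
  weights = jax.nn.softmax(logits, axis=-1).astype(v.dtype)
  return jnp.einsum("...hqk,...khd->...qhd", weights, v)
```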
38 changes: 38 additions & 0 deletions MaxText/configs/models/deepseek3-671b.yml
@@ -0,0 +1,38 @@
# Copyright 2025 Google LLC
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# https://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

# model config for DeepSeek V3 - 671B
# Please note: DeepSeek V3 is not fully supported at the moment

base_emb_dim: 7168
base_num_query_heads: 128
base_num_kv_heads: 128
base_mlp_dim: 18432
base_moe_mlp_dim: 2048
base_num_decoder_layers: 61
first_num_dense_layers: 3
head_dim: 128
mlp_activations: ["silu","linear"]
vocab_size: 32000 # TODO(b/394635939): update after adding tokenizer
enable_dropout: False
logits_via_embedding: False
normalization_layer_epsilon: 1.0e-6
num_experts: 256
num_experts_per_tok: 8
shared_experts: 1
routed_scaling_factor: 2.5
routed_score_func: "sigmoid"
routed_bias: True
rope_max_timescale: 10_000
decoder_block: "deepseek"
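Note: the routing options in this config (routed_score_func: "sigmoid", routed_bias: True, routed_scaling_factor: 2.5, num_experts_per_tok: 8) describe a DeepSeek-style router. A rough, hypothetical sketch of sigmoid-scored top-k routing under those settings follows; it is not code from this PR, and every name in it is an assumption.

```python
import jax
import jax.numpy as jnp

def route_tokens(hidden, router_w, router_bias, num_experts_per_tok=8, routed_scaling_factor=2.5):
  # Sigmoid scoring function (routed_score_func: "sigmoid"); scores: [tokens, num_experts].
  scores = jax.nn.sigmoid(hidden @ router_w)
  # Bias applied only when selecting experts (routed_bias: True).
  _, expert_ids = jax.lax.top_k(scores + router_bias, num_experts_per_tok)
  # Gate with the unbiased scores of the selected experts, normalized per token.
  weights = jnp.take_along_axis(scores, expert_ids, axis=-1)
  weights = weights / jnp.sum(weights, axis=-1, keepdims=True)
  # Scale the routed contributions (routed_scaling_factor: 2.5).
  return expert_ids, weights * routed_scaling_factor
```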
82 changes: 56 additions & 26 deletions MaxText/inference_microbenchmark.py
@@ -39,32 +39,32 @@
# pylint: disable=too-many-positional-arguments


def prefill_benchmark_loop(engine, params, tokens, true_length, iters):
def prefill_benchmark_loop(engine_prefill, params, tokens, true_length, iters):
"""Inner loop for benchmarking prefill step."""
start = datetime.datetime.now()
rng = jax.random.PRNGKey(1234)
prefill_result = None
for _ in range(iters):
rng, rng_prefill = jax.random.split(rng)
prefill_result, _ = engine.prefill(params=params, padded_tokens=tokens, true_length=true_length, rng=rng_prefill)
prefill_result, _ = engine_prefill(params, tokens, true_length, rng_prefill)
jax.block_until_ready(prefill_result)
end = datetime.datetime.now()
del prefill_result
return (end - start).total_seconds()


def prefill_benchmark(config, engine, params, tokens, true_length, num_model_params, iters):
def prefill_benchmark(config, engine_prefill, params, tokens, true_length, num_model_params, iters):
"""Handles warmup, running prefill benchmark, and printing results."""
rng = jax.random.PRNGKey(1234)
prefill_result = None
for _ in range(_WARMUP_ITERS):
rng, rng_prefill = jax.random.split(rng)
prefill_result, _ = engine.prefill(params=params, padded_tokens=tokens, true_length=true_length, rng=rng_prefill)
prefill_result, _ = engine_prefill(params, tokens, true_length, rng_prefill)
jax.block_until_ready(prefill_result)
del prefill_result

print(f"Prefill benchmark results for length {tokens.size}:\n")
time_in_s = prefill_benchmark_loop(engine, params, tokens, true_length, iters)
time_in_s = prefill_benchmark_loop(engine_prefill, params, tokens, true_length, iters)
prefill_average_ms = 1000 * time_in_s / iters
prefill_tflops_per_device, _, _ = maxtext_utils.calculate_prefill_tflops_per_device(num_model_params, tokens.size, config)
tflops_per_sec_per_device = prefill_tflops_per_device / prefill_average_ms * 1000.0
@@ -82,7 +82,7 @@ def prefill_benchmark(config, engine, params, tokens, true_length, num_model_par


def prefill_insert_benchmark_loop(
config, engine, decode_state, params, total_slots, tokens, true_length, iters, profile_name
config, engine_insert, decode_state, params, total_slots, tokens, true_length, iters, profile_name
):
"""Inner loop for benchmarking prefill and insert step."""
prof = profiler.Profiler(config)
@@ -91,59 +91,57 @@ def prefill_insert_benchmark_loop(
rng = jax.random.PRNGKey(1234)
for i in range(iters):
rng, rng_prefill = jax.random.split(rng)
prefill_result, _ = engine.prefill(params=params, padded_tokens=tokens, true_length=true_length, rng=rng_prefill)
decode_state = engine.insert(prefill_result, decode_state, int(i % total_slots))
del prefill_result
decode_state = engine_insert(tokens, true_length, rng_prefill, decode_state, int(i % total_slots), params)
jax.block_until_ready(decode_state)
end = datetime.datetime.now()
prof.deactivate()
return (end - start).total_seconds(), decode_state


def prefill_insert_benchmark(config, engine, decode_state, params, total_slots, tokens, true_length, iters):
def prefill_insert_benchmark(config, engine_insert, decode_state, params, total_slots, tokens, true_length, iters):
"""Handles warmup, running insert benchmark, and printing results."""
rng = jax.random.PRNGKey(1234)
for i in range(_WARMUP_ITERS):
rng, rng_prefill = jax.random.split(rng)
prefill_result, _ = engine.prefill(params=params, padded_tokens=tokens, true_length=true_length, rng=rng_prefill)
decode_state = engine.insert(prefill_result, decode_state, int(i % total_slots))
del prefill_result
decode_state = engine_insert(tokens, true_length, rng_prefill, decode_state, int(i % total_slots), params)
jax.block_until_ready(decode_state)

print(f"Prefill and insert benchmark results for length {tokens.size}:\n")
time_in_s, decode_state = prefill_insert_benchmark_loop(
config, engine, decode_state, params, total_slots, tokens, true_length, iters, f"prefill_insert_{tokens.size}"
config, engine_insert, decode_state, params, total_slots, tokens, true_length, iters, f"prefill_insert_{tokens.size}"
)
prefill_insert_average_ms = time_in_s / iters * 1000.0
print(f"\tPrefill + Insert step average time: {prefill_insert_average_ms:.3f} ms\n\n\n\n")
result_dict = {"time_in_ms": prefill_insert_average_ms}
return result_dict, decode_state


def ar_benchmark_loop(config, engine, params, decode_state, iters, profile_name):
def ar_benchmark_loop(config, engine_generate, params, decode_state, iters, profile_name):
"""Inner loop for benchmarking ar step."""
prof = profiler.Profiler(config)
prof.activate(optional_postfix=profile_name)
start = datetime.datetime.now()
rng = jax.random.PRNGKey(1234)
for _ in range(iters):
rng, rng_generate = jax.random.split(rng)
decode_state, _ = engine.generate(params, decode_state, rng=rng_generate)
decode_state, _ = engine_generate(params, decode_state, rng_generate)
jax.block_until_ready(decode_state)
end = datetime.datetime.now()
prof.deactivate()
return (end - start).total_seconds(), decode_state


def ar_benchmark(config, engine, params, decode_state, global_batch_size, cache_size, model_size, iters):
def ar_benchmark(config, engine_generate, params, decode_state, global_batch_size, cache_size, model_size, iters):
"""Handles warmup, running ar benchmark, and printing results."""
rng = jax.random.PRNGKey(1234)
for _ in range(_WARMUP_ITERS):
rng, rng_generate = jax.random.split(rng)
decode_state, _ = engine.generate(params, decode_state, rng=rng_generate)
decode_state, _ = engine_generate(params, decode_state, rng_generate)
jax.block_until_ready(decode_state)

time_in_s, decode_state = ar_benchmark_loop(config, engine, params, decode_state, iters, profile_name="autoregress")
time_in_s, decode_state = ar_benchmark_loop(
config, engine_generate, params, decode_state, iters, profile_name="autoregress"
)
seconds_per_step = time_in_s / iters
ar_average_ms = seconds_per_step * 1000
total_throughput = global_batch_size / seconds_per_step
@@ -224,11 +222,11 @@ def print_results_for_analyze(results):
print(f"SYSTEM_TIME_PER_DECODE_TOKEN_MS = {results['autoregressive']['step_in_ms_per_seq']}")


def summarize_prefill_result(engine, params, tokens, true_length):
def summarize_prefill_result(engine_prefill, params, tokens, true_length):
"""Summarize Prefill result."""
print(f"Prefill result of length {tokens.size}:\n")
rng = jax.random.PRNGKey(1234)
prefill_result, _ = engine.prefill(params=params, padded_tokens=tokens, true_length=true_length, rng=rng)
prefill_result, _ = engine_prefill(params, tokens, true_length, rng)
jax.block_until_ready(prefill_result)
num_prefill_logits_params, total_prefill_logits_size, avg_prefill_logits_param_size = max_utils.summarize_pytree_data(
prefill_result["logits"], name="Prefill Logits", raw=True
@@ -261,7 +259,10 @@ def run_benchmarks(config):
metadata = engine.get_tokenizer()
vocab = token_utils.load_vocab(metadata.path, metadata.extra_ids)
rng, rng_init_decode = jax.random.split(rng)
decode_state = engine.init_decode_state(rng_init_decode)

generate_executable, params, decode_state_executable = engine.aot_compile(params, pass_rng_shape=True)
decode_state = decode_state_executable(rng_init_decode)

_, cache_size, _ = max_utils.summarize_pytree_data(decode_state["cache"], name="Cache")
num_model_params, model_size, _ = max_utils.summarize_pytree_data(params, name="Model")

@@ -273,19 +274,41 @@ def run_benchmarks(config):
benchmark_results["insert"] = {}
prefill_tokens = {}
prefill_true_lengths = {}
prefill_executable = {}
prefill_insert_executable = {}
i32_scalar = jax.ShapeDtypeStruct((), int)
rng_shape = jax.ShapeDtypeStruct([4], jax.numpy.dtype("uint32"))

for prefill_length in prefill_lengths:
prefill_tokens[prefill_length], prefill_true_lengths[prefill_length] = token_utils.tokenize_and_pad(
text, vocab, is_bos=True, prefill_lengths=[prefill_length]
)

key_shape = jax.ShapeDtypeStruct([prefill_length], jax.numpy.dtype("int32"))
prefill_executable[prefill_length] = (
jax.jit(
engine.prefill_aot,
in_shardings=(engine.param_layouts, None, None, None),
).lower(params, key_shape, i32_scalar, rng_shape)
).compile(compiler_options=None)

prefill_insert_executable[prefill_length] = (
jax.jit(
engine.prefill_insert,
in_shardings=(None, None, None, engine.decode_state_layouts, None, engine.param_layouts),
out_shardings=(engine.decode_state_layouts),
donate_argnames=("decode_state",),
).lower(key_shape, i32_scalar, rng_shape, engine.decode_state_shapes, i32_scalar, params)
).compile(compiler_options=None)

benchmark_results["prefill-result-sizes"][prefill_length] = summarize_prefill_result(
engine, params, prefill_tokens[prefill_length], prefill_true_lengths[prefill_length]
prefill_executable[prefill_length], params, prefill_tokens[prefill_length], prefill_true_lengths[prefill_length]
)

for prefill_length in prefill_lengths:
benchmark_results["prefill"][prefill_length] = prefill_benchmark(
config,
engine,
prefill_executable[prefill_length],
params,
prefill_tokens[prefill_length],
prefill_true_lengths[prefill_length],
Expand All @@ -295,7 +318,7 @@ def run_benchmarks(config):

prefill_insert_time, decode_state = prefill_insert_benchmark(
config,
engine,
prefill_insert_executable[prefill_length],
decode_state,
params,
engine.max_concurrent_decodes,
Expand All @@ -310,7 +333,14 @@ def run_benchmarks(config):

if "generate" in stages_to_benchmark:
benchmark_results["autoregressive"], decode_state = ar_benchmark(
config, engine, params, decode_state, engine.max_concurrent_decodes, cache_size, model_size, benchmark_loop_iters
config,
generate_executable,
params,
decode_state,
engine.max_concurrent_decodes,
cache_size,
model_size,
benchmark_loop_iters,
)

results = collate_results(config, benchmark_results, model_size, cache_size, num_model_params)
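Note: the inference_microbenchmark.py changes above replace per-call engine.prefill / engine.insert / engine.generate invocations with ahead-of-time compiled executables (engine.aot_compile plus jax.jit(...).lower(...).compile(...)), so compilation happens once, outside the timed loops. A self-contained sketch of that lower-then-compile pattern, using a toy function instead of the engine entry points (all names here are illustrative):

```python
import jax
import jax.numpy as jnp

def matvec(w, x):
  # Stand-in for a per-length entry point such as engine.prefill_aot.
  return w @ x

w = jnp.ones((4, 4), dtype=jnp.float32)
# Abstract shapes play the role of key_shape / i32_scalar / rng_shape in the diff.
x_shape = jax.ShapeDtypeStruct((4,), jnp.float32)

# Lower against the abstract signature, then compile ahead of time.
executable = jax.jit(matvec).lower(w, x_shape).compile()

# Call the compiled executable with concrete arrays of matching shape/dtype;
# no trace-and-compile happens inside the benchmark loop.
y = executable(w, jnp.arange(4, dtype=jnp.float32))
```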