
Commit bcb4c98

Remove JAX_RANDOM_WEIGHTS

Signed-off-by: Kyuyeun Kim <[email protected]>
1 parent: 5b9984d

7 files changed: +21, -21 lines changed

.buildkite/pipeline_jax.yml

Lines changed: 1 addition & 2 deletions
@@ -190,13 +190,12 @@ steps:
       USE_V6E8_QUEUE: "True"
       SKIP_ACCURACY_TESTS: "True"
       VLLM_MLA_DISABLE: "1"
-      JAX_RANDOM_WEIGHTS: "True"
     agents:
       queue: tpu_v6e_8_queue
     commands:
       - |
         if [[ "$$NIGHTLY" == "1" ]]; then
-          .buildkite/scripts/run_in_docker.sh bash /workspace/tpu_inference/tests/e2e/benchmarking/mlperf.sh -m deepseek-ai/DeepSeek-R1-0528
+          .buildkite/scripts/run_in_docker.sh bash /workspace/tpu_inference/tests/e2e/benchmarking/mlperf.sh -m deepseek-ai/DeepSeek-R1-0528 --use-dummy-weights
         else
           echo "Skipping: NIGHTLY environment variable not set"
           exit 0

.buildkite/scripts/run_in_docker.sh

Lines changed: 0 additions & 1 deletion
@@ -108,7 +108,6 @@ exec docker run \
     ${QUANTIZATION:+-e QUANTIZATION="$QUANTIZATION"} \
     ${NEW_MODEL_DESIGN:+-e NEW_MODEL_DESIGN="$NEW_MODEL_DESIGN"} \
     ${USE_V6E8_QUEUE:+-e USE_V6E8_QUEUE="$USE_V6E8_QUEUE"} \
-    ${JAX_RANDOM_WEIGHTS:+-e JAX_RANDOM_WEIGHTS="$JAX_RANDOM_WEIGHTS"} \
     ${SKIP_ACCURACY_TESTS:+-e SKIP_ACCURACY_TESTS="$SKIP_ACCURACY_TESTS"} \
     ${VLLM_MLA_DISABLE:+-e VLLM_MLA_DISABLE="$VLLM_MLA_DISABLE"} \
     "${IMAGE_NAME}:${BUILDKITE_COMMIT}" \

tests/e2e/benchmarking/mlperf.sh

Lines changed: 14 additions & 2 deletions
@@ -40,13 +40,12 @@ else
     echo "QUANTIZATION is False. Running without quantization."
 fi
 
-echo extra_serve_args: "${extra_serve_args[@]}"
-
 root_dir=/workspace
 dataset_name=mlperf
 dataset_path=""
 num_prompts=1000
 exit_code=0
+use_dummy_weights=false
 
 helpFunction()
 {
@@ -57,6 +56,7 @@ helpFunction()
     echo -e "\t-p The path to the processed MLPerf dataset (default: None, which will download the dataset)"
     echo -e "\t-m A space-separated list of HuggingFace model ids to use (default: Qwen/Qwen2.5-1.5B-Instruct, Qwen/Qwen2.5-0.5B-Instruct, meta-llama/Llama-3.1-8B-Instruct and meta-llama/Llama-4-Scout-17B-16E-Instruct)"
     echo -e "\t-n Number of prompts to use for the benchmark (default: 10)"
+    echo -e "\t--use-dummy-weights Use dummy random weights (default: false)"
     exit 1
 }
 
@@ -87,6 +87,11 @@ while [[ "$#" -gt 0 ]]; do
         shift
         shift
         ;;
+    --use-dummy-weights)
+        use_dummy_weights=true
+        shift
+        shift
+        ;;
     -h|--help)
         helpFunction
         ;;
@@ -121,6 +126,13 @@ if [ -z "$dataset_path" ]; then
     fi
 fi
 
+if [ "$use_dummy_weights" = true ]; then
+    extra_serve_args+=("--load-format=dummy")
+fi
+
+echo extra_serve_args: "${extra_serve_args[@]}"
+
+
 echo "Using the dataset at $dataset_path"
 
 cd "$root_dir"/vllm || exit

tests/models/common/test_model_loader.py

Lines changed: 2 additions & 6 deletions
@@ -254,17 +254,13 @@ def test_get_vllm_model(mesh):
     assert callable(compute_logits_fn)
 
 
-@pytest.mark.parametrize("set_in_config", [True, False])
-def test_get_vllm_model_random_weights(mesh, set_in_config):
+def test_get_vllm_model_random_weights(mesh):
     rng = jax.random.PRNGKey(42)
 
     engine_args = EngineArgs(model="Qwen/Qwen3-0.6B")
     vllm_config = engine_args.create_engine_config()
     vllm_config.model_config.dtype = torch.bfloat16
-    if set_in_config:
-        vllm_config.load_config.load_format = "dummy"
-    else:
-        os.environ["JAX_RANDOM_WEIGHTS"] = "True"
+    vllm_config.load_config.load_format = "dummy"
 
     with set_current_vllm_config(vllm_config):
         temp_file = tempfile.mkstemp()[1]
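With the parametrization removed, the test exercises the one remaining switch, `load_config.load_format`, which is the same switch the benchmark reaches through `--load-format=dummy`. A minimal sketch of both routes (assuming `EngineArgs` is imported from `vllm.engine.arg_utils`, as the test presumably does):

from vllm.engine.arg_utils import EngineArgs  # assumed import path

# CLI route: roughly what `--load-format dummy` on the serve command produces.
cli_config = EngineArgs(model="Qwen/Qwen3-0.6B",
                        load_format="dummy").create_engine_config()

# Programmatic route: what the updated test does after building the config.
test_config = EngineArgs(model="Qwen/Qwen3-0.6B").create_engine_config()
test_config.load_config.load_format = "dummy"

# Either way, the loader below sees load_format == "dummy".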

tpu_inference/models/common/model_loader.py

Lines changed: 1 addition & 1 deletion
@@ -103,7 +103,7 @@ def create_jit_model(
             apply_to_abstract_model=False)
         return model
 
-    if os.getenv("JAX_RANDOM_WEIGHTS", False):
+    if vllm_config.load_config.load_format == "dummy":
         # Create a sharded model with randomly initialized weights.
         # TODO: Qwen2ForCausalLM currently uses the legacy model implementation;
         # merge the random-init logic once all models are migrated to the new implementation.
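When `load_format == "dummy"`, the loader builds a sharded model whose parameters are drawn from a PRNG instead of read from a checkpoint. A hypothetical sketch of that kind of dummy initialization in JAX (the parameter names and shapes here are made up for illustration):

import jax
import jax.numpy as jnp

# Draw every parameter from a split of one PRNG key, so the model is
# numerically valid (if meaningless) for performance benchmarking.
def random_params(rng, shapes):
    params = {}
    for name, shape in shapes.items():
        rng, sub = jax.random.split(rng)
        params[name] = jax.random.normal(sub, shape, dtype=jnp.bfloat16)
    return params

rng = jax.random.PRNGKey(42)  # same seeding style as the test above
params = random_params(rng, {"embed_tokens": (151936, 1024),
                             "lm_head": (1024, 151936)})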

tpu_inference/models/vllm/vllm_model_wrapper.py

Lines changed: 2 additions & 7 deletions
@@ -1,6 +1,5 @@
 import copy
 import functools
-import os
 from collections.abc import Sequence
 from contextlib import nullcontext
 from typing import Any, List, Optional, Tuple
@@ -91,12 +90,8 @@ def load_weights(self):
         # may cause errors. Therefore, we disable it during weight loading.
         vllm_config_for_load.parallel_config.enable_expert_parallel = False
 
-        if os.getenv("JAX_RANDOM_WEIGHTS", False):
-            vllm_config_for_load.load_config.load_format = "dummy"
-            use_random_weights = True
-        else:
-            use_random_weights = (
-                vllm_config_for_load.load_config.load_format == "dummy")
+        use_random_weights = (
+            vllm_config_for_load.load_config.load_format == "dummy")
         if use_random_weights:
             logger.info(
                 "Initializing vLLM model with random weights, weight loading skipped."

tpu_inference/platforms/tpu_platform.py

Lines changed: 1 addition & 2 deletions
@@ -48,8 +48,7 @@ class TpuPlatform(Platform):
     ]
 
     additional_env_vars: list[str] = [
-        "JAX_RANDOM_WEIGHTS", "PHASED_PROFILING_DIR",
-        "TPU_CHIPS_PER_HOST_BOUNDS", "TPU_HOST_BOUNDS",
+        "PHASED_PROFILING_DIR", "TPU_CHIPS_PER_HOST_BOUNDS", "TPU_HOST_BOUNDS",
         "TPU_MULTIHOST_BACKEND", "VLLM_MLA_DISABLE", "TPU_BACKEND_TYPE"
     ]
 
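`additional_env_vars` is the allowlist of host environment variables the TPU platform propagates, so after this change `JAX_RANDOM_WEIGHTS` no longer reaches workers even if set. A hypothetical illustration of how such an allowlist can be applied (this helper is not part of the codebase):

import os

ALLOWLIST = [
    "PHASED_PROFILING_DIR", "TPU_CHIPS_PER_HOST_BOUNDS", "TPU_HOST_BOUNDS",
    "TPU_MULTIHOST_BACKEND", "VLLM_MLA_DISABLE", "TPU_BACKEND_TYPE",
]

def worker_env(base=None):
    # Copy only allowlisted variables into a child-process environment.
    env = dict(base or {})
    for name in ALLOWLIST:
        if name in os.environ:
            env[name] = os.environ[name]
    return env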
