
Commit d852c82

Remove JAX_RANDOM_WEIGHTS

- The same functionality can be achieved with the vLLM argument `--load-format=dummy`
- It is better to remove duplicate configs to avoid confusing users

Signed-off-by: Kyuyeun Kim <[email protected]>

1 parent aa58a9c commit d852c82
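
The replacement relies on vLLM's standard dummy load format, which fills the model with random weights and skips checkpoint loading. A minimal sketch of the equivalent offline usage (assuming a working vLLM install; the model id is only an example):

```python
from vllm import LLM

# Equivalent of the removed JAX_RANDOM_WEIGHTS=True: vLLM's dummy loader
# initializes random weights instead of reading a checkpoint.
llm = LLM(model="Qwen/Qwen3-0.6B", load_format="dummy")
```

For a server, the same behavior comes from the CLI flag named in the commit message, e.g. `vllm serve <model> --load-format=dummy`.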

File tree

7 files changed: +25, -28 lines

.buildkite/pipeline_jax.yml
Lines changed: 1 addition & 2 deletions

```diff
@@ -190,13 +190,12 @@ steps:
       USE_V6E8_QUEUE: "True"
       SKIP_ACCURACY_TESTS: "True"
       VLLM_MLA_DISABLE: "1"
-      JAX_RANDOM_WEIGHTS: "True"
     agents:
       queue: tpu_v6e_8_queue
     commands:
     - |
       if [[ "$$NIGHTLY" == "1" ]]; then
-        .buildkite/scripts/run_in_docker.sh bash /workspace/tpu_inference/tests/e2e/benchmarking/mlperf.sh -m deepseek-ai/DeepSeek-R1-0528
+        .buildkite/scripts/run_in_docker.sh bash /workspace/tpu_inference/tests/e2e/benchmarking/mlperf.sh -m deepseek-ai/DeepSeek-R1-0528 --use-dummy-weight
       else
         echo "Skipping: NIGHTLY environment variable not set"
         exit 0
```

.buildkite/scripts/run_in_docker.sh
Lines changed: 0 additions & 1 deletion

```diff
@@ -108,7 +108,6 @@ exec docker run \
     ${QUANTIZATION:+-e QUANTIZATION="$QUANTIZATION"} \
     ${NEW_MODEL_DESIGN:+-e NEW_MODEL_DESIGN="$NEW_MODEL_DESIGN"} \
     ${USE_V6E8_QUEUE:+-e USE_V6E8_QUEUE="$USE_V6E8_QUEUE"} \
-    ${JAX_RANDOM_WEIGHTS:+-e JAX_RANDOM_WEIGHTS="$JAX_RANDOM_WEIGHTS"} \
     ${SKIP_ACCURACY_TESTS:+-e SKIP_ACCURACY_TESTS="$SKIP_ACCURACY_TESTS"} \
     ${VLLM_MLA_DISABLE:+-e VLLM_MLA_DISABLE="$VLLM_MLA_DISABLE"} \
     "${IMAGE_NAME}:${BUILDKITE_COMMIT}" \
```

tests/e2e/benchmarking/mlperf.sh
Lines changed: 13 additions & 2 deletions

```diff
@@ -40,13 +40,12 @@ else
     echo "QUANTIZATION is False. Running without quantization."
 fi

-echo extra_serve_args: "${extra_serve_args[@]}"
-
 root_dir=/workspace
 dataset_name=mlperf
 dataset_path=""
 num_prompts=1000
 exit_code=0
+use_dummy_weight=false

 helpFunction()
 {
@@ -57,6 +56,7 @@ helpFunction()
     echo -e "\t-p The path to the processed MLPerf dataset (default: None, which will download the dataset)"
     echo -e "\t-m A space-separated list of HuggingFace model ids to use (default: Qwen/Qwen2.5-1.5B-Instruct, Qwen/Qwen2.5-0.5B-Instruct, meta-llama/Llama-3.1-8B-Instruct and meta-llama/Llama-4-Scout-17B-16E-Instruct)"
     echo -e "\t-n Number of prompts to use for the benchmark (default: 10)"
+    echo -e "\t--use-dummy-weight Boolean flag that uses dummy random weight when it's set (default: false)"
     exit 1
 }

@@ -87,6 +87,11 @@ while [[ "$#" -gt 0 ]]; do
         shift
         shift
         ;;
+    --use-dummy-weight)
+        use_dummy_weight=true
+        shift
+        shift
+        ;;
     -h|--help)
         helpFunction
         ;;
@@ -121,6 +126,12 @@ if [ -z "$dataset_path" ]; then
     fi
 fi

+if [ "$use_dummy_weight" = true ]; then
+    extra_serve_args+=("--load-format=dummy")
+fi
+
+echo Using extra_serve_args: "${extra_serve_args[@]}"
+
 echo "Using the dataset at $dataset_path"

 cd "$root_dir"/vllm || exit
```

tests/models/common/test_model_loader.py
Lines changed: 2 additions & 6 deletions

```diff
@@ -252,17 +252,13 @@ def test_get_vllm_model(mesh):
     assert callable(compute_logits_fn)


-@pytest.mark.parametrize("set_in_config", [True, False])
-def test_get_vllm_model_random_weights(mesh, set_in_config):
+def test_get_vllm_model_random_weights(mesh):
     rng = jax.random.PRNGKey(42)

     engine_args = EngineArgs(model="Qwen/Qwen3-0.6B")
     vllm_config = engine_args.create_engine_config()
     vllm_config.model_config.dtype = torch.bfloat16
-    if set_in_config:
-        vllm_config.load_config.load_format = "dummy"
-    else:
-        os.environ["JAX_RANDOM_WEIGHTS"] = "True"
+    vllm_config.load_config.load_format = "dummy"

     with set_current_vllm_config(vllm_config):
         temp_file = tempfile.mkstemp()[1]
```
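
With the environment variable gone, the test exercises a single path: the load format carried in the engine config. A hedged micro-example of that flow outside the test harness (model id for illustration; config field names taken from the diff):

```python
from vllm.engine.arg_utils import EngineArgs

# The dummy format now travels through the engine config rather than
# an environment variable, so there is only one switch to test.
engine_args = EngineArgs(model="Qwen/Qwen3-0.6B", load_format="dummy")
vllm_config = engine_args.create_engine_config()
assert vllm_config.load_config.load_format == "dummy"
```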

tpu_inference/models/common/model_loader.py
Lines changed: 1 addition & 1 deletion

```diff
@@ -102,7 +102,7 @@ def create_jit_model(
             apply_to_abstract_model=False)
         return model

-    if os.getenv("JAX_RANDOM_WEIGHTS", False):
+    if vllm_config.load_config.load_format == "dummy":
         # Create a sharded model with random inited weights.
         # TODO: currently Qwen2ForCausalLM is using legacy model implementation
         # will merge the random init logic when all model are migrated to new model implementation
```
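
Beyond removing duplication, the config check is also safer than the old environment gate: `os.getenv` returns a string, so any non-empty value was truthy. A small self-contained illustration of the difference:

```python
import os

# Old gate (removed): even "False" or "0" enabled random weights,
# because any non-empty string is truthy in Python.
os.environ["JAX_RANDOM_WEIGHTS"] = "False"
print(bool(os.getenv("JAX_RANDOM_WEIGHTS", False)))  # True (surprising)

# New gate: one explicit value in the load config.
load_format = "dummy"  # stand-in for vllm_config.load_config.load_format
print(load_format == "dummy")  # True only for the dummy loader
```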

tpu_inference/models/vllm/vllm_model_wrapper.py
Lines changed: 7 additions & 14 deletions

```diff
@@ -1,6 +1,5 @@
 import copy
 import functools
-import os
 from collections.abc import Sequence
 from contextlib import nullcontext
 from typing import Any, List, Optional, Tuple
@@ -86,22 +85,16 @@ def load_weights(self):
         assert self.vllm_config.model_config.dtype in TORCH_DTYPE_TO_JAX, "The model_config.dtype must be a PyTorch dtype."
         vllm_config_for_load.device_config.device = "cpu"

-        if os.getenv("JAX_RANDOM_WEIGHTS", False):
-            vllm_config_for_load.load_config.load_format = "dummy"
-            use_random_weights = True
-        else:
-            use_random_weights = (
-                vllm_config_for_load.load_config.load_format == "dummy")
-        if use_random_weights:
+        if vllm_config_for_load.load_config.load_format == "dummy":
             logger.info(
                 "Initializing vLLM model with random weights, weight loading skipped."
             )
-        # The DummyModelLoader in vLLM calls torch._sync for torch_xla path when
-        # it detects the tpu platform, but we don't need it and it causes crash
-        # without proper setup.
-        load_context = patch(
-            "torch._sync",
-            return_value=None) if use_random_weights else nullcontext()
+            # The DummyModelLoader in vLLM calls torch._sync for torch_xla path
+            # when it detects the tpu platform, but we don't need it and it
+            # causes crash without proper setup.
+            load_context = patch("torch._sync", return_value=None)
+        else:
+            load_context = nullcontext()

         # Load the vLLM model and wrap it into a new model whose forward
         # function can calculate the hidden_state and logits.
```
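
The refactor keeps the torch._sync workaround but scopes it to the dummy branch. A standalone sketch of the guard pattern, assuming an environment where `torch` is installed and exposes `_sync` (as the diff itself relies on):

```python
from contextlib import nullcontext
from unittest.mock import patch

use_dummy_weights = True  # stand-in for load_format == "dummy"

# While dummy loading runs, torch._sync is replaced with a no-op so
# vLLM's DummyModelLoader cannot crash on the torch_xla path.
load_context = (patch("torch._sync", return_value=None)
                if use_dummy_weights else nullcontext())
with load_context:
    pass  # model construction and (skipped) weight loading happen here
```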

tpu_inference/platforms/tpu_jax.py
Lines changed: 1 addition & 2 deletions

```diff
@@ -49,8 +49,7 @@ class TpuPlatform(Platform):
     ]

     additional_env_vars: list[str] = [
-        "JAX_RANDOM_WEIGHTS", "PHASED_PROFILING_DIR",
-        "TPU_CHIPS_PER_HOST_BOUNDS", "TPU_HOST_BOUNDS",
+        "PHASED_PROFILING_DIR", "TPU_CHIPS_PER_HOST_BOUNDS", "TPU_HOST_BOUNDS",
         "TPU_MULTIHOST_BACKEND", "VLLM_MLA_DISABLE", "TPU_BACKEND_TYPE"
     ]
```
