```diff
@@ -1,6 +1,5 @@
 import copy
 import functools
-import os
 from collections.abc import Sequence
 from contextlib import nullcontext
 from typing import Any, List, Optional, Tuple
@@ -86,22 +85,16 @@ def load_weights(self):
         assert self.vllm_config.model_config.dtype in TORCH_DTYPE_TO_JAX, "The model_config.dtype must be a PyTorch dtype."
         vllm_config_for_load.device_config.device = "cpu"

-        if os.getenv("JAX_RANDOM_WEIGHTS", False):
-            vllm_config_for_load.load_config.load_format = "dummy"
-            use_random_weights = True
-        else:
-            use_random_weights = (
-                vllm_config_for_load.load_config.load_format == "dummy")
-        if use_random_weights:
+        if vllm_config_for_load.load_config.load_format == "dummy":
             logger.info(
                 "Initializing vLLM model with random weights, weight loading skipped."
             )
-        # The DummyModelLoader in vLLM calls torch._sync for torch_xla path when
-        # it detects the tpu platform, but we don't need it and it causes crash
-        # without proper setup.
-        load_context = patch(
-            "torch._sync",
-            return_value=None) if use_random_weights else nullcontext()
+            # The DummyModelLoader in vLLM calls torch._sync for torch_xla path
+            # when it detects the tpu platform, but we don't need it and it
+            # causes crash without proper setup.
+            load_context = patch("torch._sync", return_value=None)
+        else:
+            load_context = nullcontext()

         # Load the vLLM model and wrap it into a new model whose forward
         # function can calculate the hidden_state and logits.
```
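The net effect: the `JAX_RANDOM_WEIGHTS` environment-variable toggle is removed, dummy loading is keyed directly off `load_config.load_format`, and the `patch(...) if use_random_weights else nullcontext()` one-liner becomes an explicit `if`/`else` that picks the load context once. Below is a minimal, self-contained sketch of the resulting pattern, assuming `torch` is importable; the `pick_load_context` helper and the usage stub are illustrative, not code from this PR:

```python
from contextlib import nullcontext
from unittest.mock import patch


def pick_load_context(load_format: str):
    """Return the context manager the model load should run under."""
    if load_format == "dummy":
        # Stub out torch._sync only for the dummy path; this assumes a
        # torch build where torch._sync exists (the PR patches the same
        # target). Real weight loading runs unpatched via nullcontext().
        return patch("torch._sync", return_value=None)
    return nullcontext()


# Usage sketch: the body of the `with` is a placeholder, not the PR's code.
with pick_load_context("dummy"):
    pass  # load the model here; torch._sync is a no-op inside this block
```

Returning the context manager from a single branch point keeps the `with` statement unconditional, which is why the diff trades the conditional expression for a plain `if`/`else`.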