Skip to content

Commit 5b9984d

Browse files
authored
[Refactor] Remove jax suffix from file names (#1070)
1 parent 582c6e9 commit 5b9984d

26 files changed

+96
-98
lines changed
File renamed without changes.

tests/runner/test_input_batch_jax.py renamed to tests/runner/test_input_batch.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -2,7 +2,7 @@
22
import pytest
33
from vllm.sampling_params import SamplingParams
44

5-
from tpu_inference.runner.input_batch_jax import CachedRequestState, InputBatch
5+
from tpu_inference.runner.input_batch import CachedRequestState, InputBatch
66

77
# Default parameters for creating InputBatch instances in tests
88
MAX_NUM_REQS = 8

tests/runner/test_kv_cache_manager.py

Lines changed: 3 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -16,8 +16,8 @@
1616
from vllm.v1.request import Request
1717

1818
from tpu_inference import utils as common_utils
19-
from tpu_inference.runner.input_batch_jax import CachedRequestState
20-
from tpu_inference.runner.tpu_jax_runner import TPUModelRunner
19+
from tpu_inference.runner.input_batch import CachedRequestState
20+
from tpu_inference.runner.tpu_runner import TPUModelRunner
2121

2222

2323
class TestKVCacheManager:
@@ -37,7 +37,7 @@ def setup_method(self):
3737
with patch('jax.devices', return_value=self.mock_devices), \
3838
patch('jax.make_mesh', return_value=self.mock_mesh), \
3939
patch('jax.random.key', return_value=self.mock_rng_key), \
40-
patch('tpu_inference.runner.tpu_jax_runner.get_model', return_value=MagicMock()):
40+
patch('tpu_inference.runner.tpu_runner.get_model', return_value=MagicMock()):
4141

4242
model_config = ModelConfig(tokenizer_mode="auto",
4343
trust_remote_code=False,

tests/runner/test_multimodal_manager.py

Lines changed: 4 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -12,8 +12,8 @@
1212
from vllm.sampling_params import SamplingType
1313
from vllm.v1.core.sched.output import SchedulerOutput as VllmSchedulerOutput
1414

15-
from tpu_inference.runner.input_batch_jax import CachedRequestState
16-
from tpu_inference.runner.tpu_jax_runner import TPUModelRunner
15+
from tpu_inference.runner.input_batch import CachedRequestState
16+
from tpu_inference.runner.tpu_runner import TPUModelRunner
1717

1818

1919
class TestMultiModalManager:
@@ -29,8 +29,8 @@ def setup_method(self):
2929
with patch('jax.devices', return_value=self.mock_devices), \
3030
patch('jax.make_mesh', return_value=self.mock_mesh), \
3131
patch('jax.random.key', return_value=self.mock_rng_key), \
32-
patch('tpu_inference.runner.tpu_jax_runner.get_model', return_value=MagicMock()), \
33-
patch('tpu_inference.runner.tpu_jax_runner.make_optimized_mesh', return_value=self.mock_mesh):
32+
patch('tpu_inference.runner.tpu_runner.get_model', return_value=MagicMock()), \
33+
patch('tpu_inference.runner.tpu_runner.make_optimized_mesh', return_value=self.mock_mesh):
3434

3535
model_config = ModelConfig(tokenizer_mode="auto",
3636
trust_remote_code=False,

tests/runner/test_speculative_decoding_manager.py

Lines changed: 4 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -8,10 +8,10 @@
88
from vllm.sampling_params import SamplingType
99
from vllm.v1.outputs import DraftTokenIds
1010

11-
from tpu_inference.runner.input_batch_jax import CachedRequestState, InputBatch
11+
from tpu_inference.runner.input_batch import CachedRequestState, InputBatch
1212
from tpu_inference.runner.speculative_decoding_manager import \
1313
SpecDecodeMetadata
14-
from tpu_inference.runner.tpu_jax_runner import TPUModelRunner
14+
from tpu_inference.runner.tpu_runner import TPUModelRunner
1515
from tpu_inference.spec_decode.jax.eagle3 import Eagle3Proposer
1616

1717

@@ -28,8 +28,8 @@ def setup_method(self):
2828
with patch('jax.devices', return_value=self.mock_devices), \
2929
patch('jax.make_mesh', return_value=self.mock_mesh), \
3030
patch('jax.random.key', return_value=self.mock_rng_key), \
31-
patch('tpu_inference.runner.tpu_jax_runner.get_model', return_value=MagicMock()), \
32-
patch('tpu_inference.runner.tpu_jax_runner.make_optimized_mesh', return_value=self.mock_mesh):
31+
patch('tpu_inference.runner.tpu_runner.get_model', return_value=MagicMock()), \
32+
patch('tpu_inference.runner.tpu_runner.make_optimized_mesh', return_value=self.mock_mesh):
3333

3434
model_config = ModelConfig(tokenizer_mode="auto",
3535
trust_remote_code=False,

tests/runner/test_structured_decoding_manager.py

Lines changed: 4 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -7,8 +7,8 @@
77
SchedulerConfig, SpeculativeConfig, VllmConfig)
88
from vllm.sampling_params import SamplingType
99

10-
from tpu_inference.runner.input_batch_jax import CachedRequestState
11-
from tpu_inference.runner.tpu_jax_runner import TPUModelRunner
10+
from tpu_inference.runner.input_batch import CachedRequestState
11+
from tpu_inference.runner.tpu_runner import TPUModelRunner
1212

1313

1414
class TestStructuredDecodingManager:
@@ -26,8 +26,8 @@ def setup_method(self):
2626
with patch('jax.devices', return_value=self.mock_devices), \
2727
patch('jax.make_mesh', return_value=self.mock_mesh), \
2828
patch('jax.random.key', return_value=self.mock_rng_key), \
29-
patch('tpu_inference.runner.tpu_jax_runner.get_model', return_value=MagicMock()), \
30-
patch('tpu_inference.runner.tpu_jax_runner.make_optimized_mesh', return_value=self.mock_mesh):
29+
patch('tpu_inference.runner.tpu_runner.get_model', return_value=MagicMock()), \
30+
patch('tpu_inference.runner.tpu_runner.make_optimized_mesh', return_value=self.mock_mesh):
3131

3232
model_config = ModelConfig(tokenizer_mode="auto",
3333
trust_remote_code=False,

tests/runner/test_tpu_jax_runner.py renamed to tests/runner/test_tpu_runner.py

Lines changed: 6 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -6,7 +6,7 @@
66
from vllm.config import (CacheConfig, ModelConfig, ParallelConfig,
77
SchedulerConfig, SpeculativeConfig, VllmConfig)
88

9-
from tpu_inference.runner.tpu_jax_runner import TPUModelRunner
9+
from tpu_inference.runner.tpu_runner import TPUModelRunner
1010

1111

1212
class TestTPUJaxRunner:
@@ -21,8 +21,8 @@ def setup_method(self):
2121
with patch('jax.devices', return_value=self.mock_devices), \
2222
patch('jax.make_mesh', return_value=self.mock_mesh), \
2323
patch('jax.random.key', return_value=self.mock_rng_key), \
24-
patch('tpu_inference.runner.tpu_jax_runner.get_model', return_value=MagicMock()), \
25-
patch('tpu_inference.runner.tpu_jax_runner.make_optimized_mesh', return_value=self.mock_mesh):
24+
patch('tpu_inference.runner.tpu_runner.get_model', return_value=MagicMock()), \
25+
patch('tpu_inference.runner.tpu_runner.make_optimized_mesh', return_value=self.mock_mesh):
2626

2727
model_config = ModelConfig(tokenizer_mode="auto",
2828
trust_remote_code=False,
@@ -114,9 +114,9 @@ def setup_method(self):
114114
with patch('jax.devices', return_value=self.mock_devices), \
115115
patch('jax.make_mesh', return_value=self.mock_mesh), \
116116
patch('jax.random.key', return_value=self.mock_rng_key), \
117-
patch('tpu_inference.runner.tpu_jax_runner.nnx.Rngs', return_value=self.mock_rng_key), \
118-
patch('tpu_inference.runner.tpu_jax_runner.get_model', return_value=self._model_get_model()), \
119-
patch('tpu_inference.runner.tpu_jax_runner.make_optimized_mesh', return_value=self.mock_mesh):
117+
patch('tpu_inference.runner.tpu_runner.nnx.Rngs', return_value=self.mock_rng_key), \
118+
patch('tpu_inference.runner.tpu_runner.get_model', return_value=self._model_get_model()), \
119+
patch('tpu_inference.runner.tpu_runner.make_optimized_mesh', return_value=self.mock_mesh):
120120

121121
model_config = ModelConfig(tokenizer_mode="auto",
122122
trust_remote_code=False,

tests/runner/test_tpu_jax_runner_dp.py renamed to tests/runner/test_tpu_runner_dp.py

Lines changed: 33 additions & 33 deletions
Original file line numberDiff line numberDiff line change
@@ -4,7 +4,7 @@
44
import numpy as np
55
import pytest
66

7-
from tpu_inference.runner.tpu_jax_runner import TPUModelRunner
7+
from tpu_inference.runner.tpu_runner import TPUModelRunner
88

99

1010
class TestTPUJaxRunnerDPInputsLightweight:
@@ -76,11 +76,11 @@ def _create_mock_scheduler_output(self,
7676
mock_output.grammar_bitmask = None
7777
return mock_output
7878

79-
@patch('tpu_inference.runner.tpu_jax_runner.NamedSharding')
80-
@patch('tpu_inference.runner.tpu_jax_runner.runner_utils')
81-
@patch('tpu_inference.runner.tpu_jax_runner.device_array',
79+
@patch('tpu_inference.runner.tpu_runner.NamedSharding')
80+
@patch('tpu_inference.runner.tpu_runner.runner_utils')
81+
@patch('tpu_inference.runner.tpu_runner.device_array',
8282
side_effect=lambda mesh, tensors, **kwargs: tensors)
83-
@patch('tpu_inference.runner.tpu_jax_runner.TPUSupportedSamplingMetadata')
83+
@patch('tpu_inference.runner.tpu_runner.TPUSupportedSamplingMetadata')
8484
def test_prepare_inputs_dp_basic_functionality(self,
8585
mock_sampling_metadata,
8686
mock_device_array,
@@ -136,7 +136,7 @@ def test_prepare_dp_input_metadata(self):
136136
scheduler_output = self._create_mock_scheduler_output(
137137
num_scheduled_tokens, assigned_dp_ranks)
138138

139-
with patch('tpu_inference.runner.tpu_jax_runner.runner_utils'
139+
with patch('tpu_inference.runner.tpu_runner.runner_utils'
140140
) as mock_runner_utils:
141141
mock_runner_utils.get_padded_token_len.side_effect = lambda paddings_list, val: 16 if val <= 15 else 32 # Padded tokens per DP rank
142142

@@ -210,7 +210,7 @@ def test_prepare_dp_input_metadata_empty_rank(self):
210210
scheduler_output = self._create_mock_scheduler_output(
211211
num_scheduled_tokens, assigned_dp_ranks)
212212

213-
with patch('tpu_inference.runner.tpu_jax_runner.runner_utils'
213+
with patch('tpu_inference.runner.tpu_runner.runner_utils'
214214
) as mock_runner_utils:
215215
mock_runner_utils.get_padded_token_len.side_effect = lambda paddings_list, val: 16 if val <= 15 else 32
216216

@@ -287,7 +287,7 @@ def test_prepare_dp_input_metadata_logits_indices_selector_ordering(self):
287287
scheduler_output = self._create_mock_scheduler_output(
288288
num_scheduled_tokens, assigned_dp_ranks)
289289

290-
with patch('tpu_inference.runner.tpu_jax_runner.runner_utils'
290+
with patch('tpu_inference.runner.tpu_runner.runner_utils'
291291
) as mock_runner_utils:
292292
mock_runner_utils.get_padded_token_len.side_effect = lambda paddings_list, val: 8 if val <= 6 else 16
293293

@@ -316,11 +316,11 @@ def test_prepare_dp_input_metadata_logits_indices_selector_ordering(self):
316316
np.testing.assert_array_equal(logits_indices_selector,
317317
expected_positions)
318318

319-
@patch('tpu_inference.runner.tpu_jax_runner.NamedSharding')
320-
@patch('tpu_inference.runner.tpu_jax_runner.runner_utils')
321-
@patch('tpu_inference.runner.tpu_jax_runner.device_array',
319+
@patch('tpu_inference.runner.tpu_runner.NamedSharding')
320+
@patch('tpu_inference.runner.tpu_runner.runner_utils')
321+
@patch('tpu_inference.runner.tpu_runner.device_array',
322322
side_effect=lambda mesh, tensors, **kwargs: tensors)
323-
@patch('tpu_inference.runner.tpu_jax_runner.TPUSupportedSamplingMetadata')
323+
@patch('tpu_inference.runner.tpu_runner.TPUSupportedSamplingMetadata')
324324
def test_prepare_inputs_dp_verify_content_balanced(self,
325325
mock_sampling_metadata,
326326
mock_device_array,
@@ -433,11 +433,11 @@ def mock_get_padded_token_len(paddings_list, val):
433433
assert len(logits_indices_selector) == 2
434434
assert np.array_equal(logits_indices_selector, np.array([0, 4]))
435435

436-
@patch('tpu_inference.runner.tpu_jax_runner.NamedSharding')
437-
@patch('tpu_inference.runner.tpu_jax_runner.runner_utils')
438-
@patch('tpu_inference.runner.tpu_jax_runner.device_array',
436+
@patch('tpu_inference.runner.tpu_runner.NamedSharding')
437+
@patch('tpu_inference.runner.tpu_runner.runner_utils')
438+
@patch('tpu_inference.runner.tpu_runner.device_array',
439439
side_effect=lambda mesh, tensors, **kwargs: tensors)
440-
@patch('tpu_inference.runner.tpu_jax_runner.TPUSupportedSamplingMetadata')
440+
@patch('tpu_inference.runner.tpu_runner.TPUSupportedSamplingMetadata')
441441
def test_prepare_inputs_dp_verify_content_empty_rank(
442442
self, mock_sampling_metadata, mock_device_array, mock_runner_utils,
443443
mock_named_sharding):
@@ -558,11 +558,11 @@ def mock_get_padded_token_len(paddings_list, val):
558558
np.testing.assert_array_equal(logits_indices_selector,
559559
expected_selector)
560560

561-
@patch('tpu_inference.runner.tpu_jax_runner.NamedSharding')
562-
@patch('tpu_inference.runner.tpu_jax_runner.runner_utils')
563-
@patch('tpu_inference.runner.tpu_jax_runner.device_array',
561+
@patch('tpu_inference.runner.tpu_runner.NamedSharding')
562+
@patch('tpu_inference.runner.tpu_runner.runner_utils')
563+
@patch('tpu_inference.runner.tpu_runner.device_array',
564564
side_effect=lambda mesh, tensors, **kwargs: tensors)
565-
@patch('tpu_inference.runner.tpu_jax_runner.TPUSupportedSamplingMetadata')
565+
@patch('tpu_inference.runner.tpu_runner.TPUSupportedSamplingMetadata')
566566
def test_prepare_async_token_substitution_indices_dp(
567567
self, mock_sampling_metadata, mock_device_array, mock_runner_utils,
568568
mock_named_sharding):
@@ -598,11 +598,11 @@ def test_prepare_async_token_substitution_indices_dp(
598598
assert token_in_tpu_cur_input_indices_dp[1] == [11]
599599
assert token_in_tpu_pre_next_tokens_indices_dp[1] == [2]
600600

601-
@patch('tpu_inference.runner.tpu_jax_runner.NamedSharding')
602-
@patch('tpu_inference.runner.tpu_jax_runner.runner_utils')
603-
@patch('tpu_inference.runner.tpu_jax_runner.device_array',
601+
@patch('tpu_inference.runner.tpu_runner.NamedSharding')
602+
@patch('tpu_inference.runner.tpu_runner.runner_utils')
603+
@patch('tpu_inference.runner.tpu_runner.device_array',
604604
side_effect=lambda mesh, tensors, **kwargs: tensors)
605-
@patch('tpu_inference.runner.tpu_jax_runner.TPUSupportedSamplingMetadata')
605+
@patch('tpu_inference.runner.tpu_runner.TPUSupportedSamplingMetadata')
606606
def test_prepare_async_token_substitution_indices_dp_no_placeholders(
607607
self, mock_sampling_metadata, mock_device_array, mock_runner_utils,
608608
mock_named_sharding):
@@ -652,7 +652,7 @@ def test_apply_async_token_substitution_empty_indices(self):
652652
# Should return input_ids unchanged
653653
np.testing.assert_array_equal(result, input_ids)
654654

655-
@patch('tpu_inference.runner.tpu_jax_runner.device_array',
655+
@patch('tpu_inference.runner.tpu_runner.device_array',
656656
side_effect=lambda mesh, tensors, **kwargs: tensors)
657657
def test_apply_async_token_substitution_with_padding(
658658
self, mock_device_array):
@@ -733,11 +733,11 @@ def test_prepare_inputs_routing_to_non_dp(self):
733733
self.runner._prepare_inputs_non_dp.assert_called_once_with(
734734
scheduler_output)
735735

736-
@patch('tpu_inference.runner.tpu_jax_runner.NamedSharding')
737-
@patch('tpu_inference.runner.tpu_jax_runner.runner_utils')
738-
@patch('tpu_inference.runner.tpu_jax_runner.device_array',
736+
@patch('tpu_inference.runner.tpu_runner.NamedSharding')
737+
@patch('tpu_inference.runner.tpu_runner.runner_utils')
738+
@patch('tpu_inference.runner.tpu_runner.device_array',
739739
side_effect=lambda mesh, tensors, **kwargs: tensors)
740-
@patch('tpu_inference.runner.tpu_jax_runner.TPUSupportedSamplingMetadata')
740+
@patch('tpu_inference.runner.tpu_runner.TPUSupportedSamplingMetadata')
741741
def test_prepare_inputs_dp_with_async_scheduling(self,
742742
mock_sampling_metadata,
743743
mock_device_array,
@@ -806,11 +806,11 @@ def mock_get_padded_token_len(paddings_list, val):
806806
# Verify async token substitution was called
807807
mock_prepare_async.assert_called_once()
808808

809-
@patch('tpu_inference.runner.tpu_jax_runner.NamedSharding')
810-
@patch('tpu_inference.runner.tpu_jax_runner.runner_utils')
811-
@patch('tpu_inference.runner.tpu_jax_runner.device_array',
809+
@patch('tpu_inference.runner.tpu_runner.NamedSharding')
810+
@patch('tpu_inference.runner.tpu_runner.runner_utils')
811+
@patch('tpu_inference.runner.tpu_runner.device_array',
812812
side_effect=lambda mesh, tensors, **kwargs: tensors)
813-
@patch('tpu_inference.runner.tpu_jax_runner.TPUSupportedSamplingMetadata')
813+
@patch('tpu_inference.runner.tpu_runner.TPUSupportedSamplingMetadata')
814814
def test_prepare_inputs_dp_async_token_substitution_application(
815815
self, mock_sampling_metadata, mock_device_array, mock_runner_utils,
816816
mock_named_sharding):

tests/runner/test_tpu_jax_runner_mesh.py renamed to tests/runner/test_tpu_runner_mesh.py

Lines changed: 9 additions & 9 deletions
Original file line numberDiff line numberDiff line change
@@ -4,7 +4,7 @@
44

55
import pytest
66

7-
from tpu_inference.runner.tpu_jax_runner import TPUModelRunner
7+
from tpu_inference.runner.tpu_runner import TPUModelRunner
88

99

1010
class TestTPUModelRunnerMeshInit:
@@ -54,8 +54,8 @@ def test_init_mesh_2d_model_without_device_order(self, runner_instance,
5454
mock_vllm_config):
5555
"""Test 2d mesh creation without enforced device order."""
5656
with patch.dict(os.environ, {'NEW_MODEL_DESIGN': ''}), \
57-
patch('tpu_inference.runner.tpu_jax_runner.make_optimized_mesh') as mock_make_mesh, \
58-
patch('tpu_inference.runner.tpu_jax_runner.logger'):
57+
patch('tpu_inference.runner.tpu_runner.make_optimized_mesh') as mock_make_mesh, \
58+
patch('tpu_inference.runner.tpu_runner.logger'):
5959

6060
mock_mesh = Mock()
6161
mock_make_mesh.return_value = mock_mesh
@@ -81,7 +81,7 @@ def test_init_mesh_2d_model_with_device_order(self, runner_instance,
8181

8282
with patch.dict(os.environ, {'NEW_MODEL_DESIGN': ''}), \
8383
patch('jax.make_mesh') as mock_jax_mesh, \
84-
patch('tpu_inference.runner.tpu_jax_runner.logger'):
84+
patch('tpu_inference.runner.tpu_runner.logger'):
8585

8686
mock_mesh = Mock()
8787
mock_jax_mesh.return_value = mock_mesh
@@ -104,9 +104,9 @@ def test_init_mesh_new_model_single_slice(self, runner_instance,
104104
mock_vllm_config):
105105
"""Test new model mesh creation with single slice."""
106106
with patch.dict(os.environ, {'NEW_MODEL_DESIGN': '1', 'NUM_SLICES': '1'}), \
107-
patch('tpu_inference.runner.tpu_jax_runner.mesh_utils') as mock_mesh_utils, \
107+
patch('tpu_inference.runner.tpu_runner.mesh_utils') as mock_mesh_utils, \
108108
patch('jax.sharding.Mesh') as mock_jax_mesh, \
109-
patch('tpu_inference.runner.tpu_jax_runner.logger'):
109+
patch('tpu_inference.runner.tpu_runner.logger'):
110110

111111
mock_devices_array = Mock()
112112
mock_mesh_utils.create_device_mesh.return_value = mock_devices_array
@@ -135,9 +135,9 @@ def test_init_mesh_new_model_multi_slice(self, runner_instance,
135135
"""Test new model mesh creation with multiple slices."""
136136
num_slices = 2
137137
with patch.dict(os.environ, {'NEW_MODEL_DESIGN': '1', 'NUM_SLICES': str(num_slices)}), \
138-
patch('tpu_inference.runner.tpu_jax_runner.mesh_utils') as mock_mesh_utils, \
138+
patch('tpu_inference.runner.tpu_runner.mesh_utils') as mock_mesh_utils, \
139139
patch('jax.sharding.Mesh') as mock_jax_mesh, \
140-
patch('tpu_inference.runner.tpu_jax_runner.logger'):
140+
patch('tpu_inference.runner.tpu_runner.logger'):
141141

142142
mock_devices_array = Mock()
143143
mock_mesh_utils.create_hybrid_device_mesh.return_value = mock_devices_array
@@ -174,7 +174,7 @@ def test_multi_slice_mesh_dp_inner_calculation(self, runner_instance,
174174
num_slices,
175175
expected_dp_inner):
176176
"""Test dp_inner calculation for various num_slices values."""
177-
with patch('tpu_inference.runner.tpu_jax_runner.mesh_utils'
177+
with patch('tpu_inference.runner.tpu_runner.mesh_utils'
178178
) as mock_mesh_utils:
179179
mock_mesh_utils.create_hybrid_device_mesh.return_value = Mock()
180180

0 commit comments

Comments
 (0)