4 | 4 | import numpy as np
5 | 5 | import pytest
6 | 6 |
7 | | -from tpu_inference.runner.tpu_jax_runner import TPUModelRunner
| 7 | +from tpu_inference.runner.tpu_runner import TPUModelRunner
8 | 8 |
9 | 9 |
10 | 10 | class TestTPUJaxRunnerDPInputsLightweight:
@@ -76,11 +76,11 @@ def _create_mock_scheduler_output(self,
76 | 76 | mock_output.grammar_bitmask = None
77 | 77 | return mock_output
78 | 78 |
79 | | - @patch('tpu_inference.runner.tpu_jax_runner.NamedSharding')
80 | | - @patch('tpu_inference.runner.tpu_jax_runner.runner_utils')
81 | | - @patch('tpu_inference.runner.tpu_jax_runner.device_array',
| 79 | + @patch('tpu_inference.runner.tpu_runner.NamedSharding')
| 80 | + @patch('tpu_inference.runner.tpu_runner.runner_utils')
| 81 | + @patch('tpu_inference.runner.tpu_runner.device_array',
82 | 82 | side_effect=lambda mesh, tensors, **kwargs: tensors)
83 | | - @patch('tpu_inference.runner.tpu_jax_runner.TPUSupportedSamplingMetadata')
| 83 | + @patch('tpu_inference.runner.tpu_runner.TPUSupportedSamplingMetadata')
84 | 84 | def test_prepare_inputs_dp_basic_functionality(self,
85 | 85 | mock_sampling_metadata,
86 | 86 | mock_device_array,
@@ -136,7 +136,7 @@ def test_prepare_dp_input_metadata(self):
136 | 136 | scheduler_output = self._create_mock_scheduler_output(
137 | 137 | num_scheduled_tokens, assigned_dp_ranks)
138 | 138 |
139 | | - with patch('tpu_inference.runner.tpu_jax_runner.runner_utils'
| 139 | + with patch('tpu_inference.runner.tpu_runner.runner_utils'
140 | 140 | ) as mock_runner_utils:
141 | 141 | mock_runner_utils.get_padded_token_len.side_effect = lambda paddings_list, val: 16 if val <= 15 else 32 # Padded tokens per DP rank
142 | 142 |
@@ -210,7 +210,7 @@ def test_prepare_dp_input_metadata_empty_rank(self):
210 | 210 | scheduler_output = self._create_mock_scheduler_output(
211 | 211 | num_scheduled_tokens, assigned_dp_ranks)
212 | 212 |
213 | | - with patch('tpu_inference.runner.tpu_jax_runner.runner_utils'
| 213 | + with patch('tpu_inference.runner.tpu_runner.runner_utils'
214 | 214 | ) as mock_runner_utils:
215 | 215 | mock_runner_utils.get_padded_token_len.side_effect = lambda paddings_list, val: 16 if val <= 15 else 32
216 | 216 |
@@ -287,7 +287,7 @@ def test_prepare_dp_input_metadata_logits_indices_selector_ordering(self):
287 | 287 | scheduler_output = self._create_mock_scheduler_output(
288 | 288 | num_scheduled_tokens, assigned_dp_ranks)
289 | 289 |
290 | | - with patch('tpu_inference.runner.tpu_jax_runner.runner_utils'
| 290 | + with patch('tpu_inference.runner.tpu_runner.runner_utils'
291 | 291 | ) as mock_runner_utils:
292 | 292 | mock_runner_utils.get_padded_token_len.side_effect = lambda paddings_list, val: 8 if val <= 6 else 16
293 | 293 |
@@ -316,11 +316,11 @@ def test_prepare_dp_input_metadata_logits_indices_selector_ordering(self):
316 | 316 | np.testing.assert_array_equal(logits_indices_selector,
317 | 317 | expected_positions)
318 | 318 |
319 | | - @patch('tpu_inference.runner.tpu_jax_runner.NamedSharding')
320 | | - @patch('tpu_inference.runner.tpu_jax_runner.runner_utils')
321 | | - @patch('tpu_inference.runner.tpu_jax_runner.device_array',
| 319 | + @patch('tpu_inference.runner.tpu_runner.NamedSharding')
| 320 | + @patch('tpu_inference.runner.tpu_runner.runner_utils')
| 321 | + @patch('tpu_inference.runner.tpu_runner.device_array',
322 | 322 | side_effect=lambda mesh, tensors, **kwargs: tensors)
323 | | - @patch('tpu_inference.runner.tpu_jax_runner.TPUSupportedSamplingMetadata')
| 323 | + @patch('tpu_inference.runner.tpu_runner.TPUSupportedSamplingMetadata')
324 | 324 | def test_prepare_inputs_dp_verify_content_balanced(self,
325 | 325 | mock_sampling_metadata,
326 | 326 | mock_device_array,
@@ -433,11 +433,11 @@ def mock_get_padded_token_len(paddings_list, val):
433 | 433 | assert len(logits_indices_selector) == 2
434 | 434 | assert np.array_equal(logits_indices_selector, np.array([0, 4]))
435 | 435 |
436 | | - @patch('tpu_inference.runner.tpu_jax_runner.NamedSharding')
437 | | - @patch('tpu_inference.runner.tpu_jax_runner.runner_utils')
438 | | - @patch('tpu_inference.runner.tpu_jax_runner.device_array',
| 436 | + @patch('tpu_inference.runner.tpu_runner.NamedSharding')
| 437 | + @patch('tpu_inference.runner.tpu_runner.runner_utils')
| 438 | + @patch('tpu_inference.runner.tpu_runner.device_array',
439 | 439 | side_effect=lambda mesh, tensors, **kwargs: tensors)
440 | | - @patch('tpu_inference.runner.tpu_jax_runner.TPUSupportedSamplingMetadata')
| 440 | + @patch('tpu_inference.runner.tpu_runner.TPUSupportedSamplingMetadata')
441 | 441 | def test_prepare_inputs_dp_verify_content_empty_rank(
442 | 442 | self, mock_sampling_metadata, mock_device_array, mock_runner_utils,
443 | 443 | mock_named_sharding):
@@ -558,11 +558,11 @@ def mock_get_padded_token_len(paddings_list, val):
558 | 558 | np.testing.assert_array_equal(logits_indices_selector,
559 | 559 | expected_selector)
560 | 560 |
561 | | - @patch('tpu_inference.runner.tpu_jax_runner.NamedSharding')
562 | | - @patch('tpu_inference.runner.tpu_jax_runner.runner_utils')
563 | | - @patch('tpu_inference.runner.tpu_jax_runner.device_array',
| 561 | + @patch('tpu_inference.runner.tpu_runner.NamedSharding')
| 562 | + @patch('tpu_inference.runner.tpu_runner.runner_utils')
| 563 | + @patch('tpu_inference.runner.tpu_runner.device_array',
564 | 564 | side_effect=lambda mesh, tensors, **kwargs: tensors)
565 | | - @patch('tpu_inference.runner.tpu_jax_runner.TPUSupportedSamplingMetadata')
| 565 | + @patch('tpu_inference.runner.tpu_runner.TPUSupportedSamplingMetadata')
566 | 566 | def test_prepare_async_token_substitution_indices_dp(
567 | 567 | self, mock_sampling_metadata, mock_device_array, mock_runner_utils,
568 | 568 | mock_named_sharding):
@@ -598,11 +598,11 @@ def test_prepare_async_token_substitution_indices_dp(
598 | 598 | assert token_in_tpu_cur_input_indices_dp[1] == [11]
599 | 599 | assert token_in_tpu_pre_next_tokens_indices_dp[1] == [2]
600 | 600 |
601 | | - @patch('tpu_inference.runner.tpu_jax_runner.NamedSharding')
602 | | - @patch('tpu_inference.runner.tpu_jax_runner.runner_utils')
603 | | - @patch('tpu_inference.runner.tpu_jax_runner.device_array',
| 601 | + @patch('tpu_inference.runner.tpu_runner.NamedSharding')
| 602 | + @patch('tpu_inference.runner.tpu_runner.runner_utils')
| 603 | + @patch('tpu_inference.runner.tpu_runner.device_array',
604 | 604 | side_effect=lambda mesh, tensors, **kwargs: tensors)
605 | | - @patch('tpu_inference.runner.tpu_jax_runner.TPUSupportedSamplingMetadata')
| 605 | + @patch('tpu_inference.runner.tpu_runner.TPUSupportedSamplingMetadata')
606 | 606 | def test_prepare_async_token_substitution_indices_dp_no_placeholders(
607 | 607 | self, mock_sampling_metadata, mock_device_array, mock_runner_utils,
608 | 608 | mock_named_sharding):
@@ -652,7 +652,7 @@ def test_apply_async_token_substitution_empty_indices(self):
652 | 652 | # Should return input_ids unchanged
653 | 653 | np.testing.assert_array_equal(result, input_ids)
654 | 654 |
655 | | - @patch('tpu_inference.runner.tpu_jax_runner.device_array',
| 655 | + @patch('tpu_inference.runner.tpu_runner.device_array',
656 | 656 | side_effect=lambda mesh, tensors, **kwargs: tensors)
657 | 657 | def test_apply_async_token_substitution_with_padding(
658 | 658 | self, mock_device_array):
@@ -733,11 +733,11 @@ def test_prepare_inputs_routing_to_non_dp(self):
733 | 733 | self.runner._prepare_inputs_non_dp.assert_called_once_with(
734 | 734 | scheduler_output)
735 | 735 |
736 | | - @patch('tpu_inference.runner.tpu_jax_runner.NamedSharding')
737 | | - @patch('tpu_inference.runner.tpu_jax_runner.runner_utils')
738 | | - @patch('tpu_inference.runner.tpu_jax_runner.device_array',
| 736 | + @patch('tpu_inference.runner.tpu_runner.NamedSharding')
| 737 | + @patch('tpu_inference.runner.tpu_runner.runner_utils')
| 738 | + @patch('tpu_inference.runner.tpu_runner.device_array',
739 | 739 | side_effect=lambda mesh, tensors, **kwargs: tensors)
740 | | - @patch('tpu_inference.runner.tpu_jax_runner.TPUSupportedSamplingMetadata')
| 740 | + @patch('tpu_inference.runner.tpu_runner.TPUSupportedSamplingMetadata')
741 | 741 | def test_prepare_inputs_dp_with_async_scheduling(self,
742 | 742 | mock_sampling_metadata,
743 | 743 | mock_device_array,
@@ -806,11 +806,11 @@ def mock_get_padded_token_len(paddings_list, val):
806 | 806 | # Verify async token substitution was called
807 | 807 | mock_prepare_async.assert_called_once()
808 | 808 |
809 | | - @patch('tpu_inference.runner.tpu_jax_runner.NamedSharding')
810 | | - @patch('tpu_inference.runner.tpu_jax_runner.runner_utils')
811 | | - @patch('tpu_inference.runner.tpu_jax_runner.device_array',
| 809 | + @patch('tpu_inference.runner.tpu_runner.NamedSharding')
| 810 | + @patch('tpu_inference.runner.tpu_runner.runner_utils')
| 811 | + @patch('tpu_inference.runner.tpu_runner.device_array',
812 | 812 | side_effect=lambda mesh, tensors, **kwargs: tensors)
813 | | - @patch('tpu_inference.runner.tpu_jax_runner.TPUSupportedSamplingMetadata')
| 813 | + @patch('tpu_inference.runner.tpu_runner.TPUSupportedSamplingMetadata')
814 | 814 | def test_prepare_inputs_dp_async_token_substitution_application(
815 | 815 | self, mock_sampling_metadata, mock_device_array, mock_runner_utils,
816 | 816 | mock_named_sharding):