Commit e9a705b

in the middle of refactoring
Signed-off-by: Xiongfei Wei <[email protected]>
1 parent fde81d3 commit e9a705b

File tree

1 file changed: +79 -75 lines changed


tests/lora/test_layers.py

Lines changed: 79 additions & 75 deletions
@@ -202,6 +202,8 @@ def create_random_inputs(
 @pytest.mark.parametrize("repeats", [2])
 @pytest.mark.parametrize("stage", [True, False])
 def test_column_parallel_packed(dist_init, num_loras, repeats, stage) -> None:
+    set_random_seed(6)
+
     max_loras = 9
     max_lora_rank = 8
     lora_config = LoRAConfig(
@@ -213,24 +215,12 @@ def test_column_parallel_packed(dist_init, num_loras, repeats, stage) -> None:
     vllm_config = dist_init
     vllm_config.lora_config = lora_config
 
-    axis_names = ("data", "model")
-    devices = jax.devices()
-    mesh_shape = (1, len(devices))
-    mesh = jax.make_mesh(mesh_shape, axis_names, devices=devices)
-
-    set_random_seed(6)
-
+    mesh = _create_mesh()
     linear, lora_linear = _create_column_parallel_packed_layer(
         repeats, vllm_config, mesh)
-    with torchax.default_env():
-        # lora_linear.weight has type torchax.tensor.Tensor
-        # BaseLinearLayerWithLoRA.weight property guarantees this.
-        # if len(devices) != 1, `reorder_concatenated_tensor_for_sharding` function may reorder the out_features dimension of the weight matrix.
-        # So the below check will fail.
-        if len(devices) == 1:
-            assert torch.equal(linear.weight.data,
-                               lora_linear.weight.to('cpu'))
+    _verify_lora_linear_layer(linear, lora_linear)
 
+    # Create a punica wrapper and associate it with the lora linear layer.
     max_num_batched_tokens = 8192
     max_batches = 256
     with torchax.default_env():
@@ -251,6 +241,7 @@ def test_column_parallel_packed(dist_init, num_loras, repeats, stage) -> None:
             repeats=repeats,
         )
 
+    # Create inputs and lora mappings.
     # inputs: list[torch.Tensor] of size num_inputs. inputs[i] corresponds to a request which has several token of shape=[num_tokens, 64].
     # index_mapping: list[int]
     # prompt_mapping: list[int]
@@ -261,35 +252,14 @@ def test_column_parallel_packed(dist_init, num_loras, repeats, stage) -> None:
         input_range=(0, 1),
         input_type=torch.float16,
         device='cpu')
-    lora_mapping = LoRAMapping(index_mapping, prompt_mapping, is_prefill=stage)
 
-    with torchax.default_env():
-        # Here we move the metadata from cpu to tpu.
-        punica_wrapper.update_metadata(
-            lora_mapping,
-            index_to_id,
-            max_loras,
-            vocab_size=512,
-            extra_vocab_size=lora_config.lora_extra_vocab_size,
-        )
-        assert jax_view(punica_wrapper._lora_indices_per_batch).platform(
-        ) == 'tpu', 'punica_wrapper._lora_indices_per_batch should have been moved to TPU.'
-        assert isinstance(
-            jax_view(punica_wrapper._lora_indices_per_batch).sharding,
-            jax.sharding.SingleDeviceSharding
-        ), 'punica_wrapper._lora_indices_per_batch should have been moved to TPU.'
+    _update_punica_wrapper_metadata(punica_wrapper, index_mapping,
+                                    prompt_mapping, stage, index_to_id,
+                                    lora_config)
 
-    jax_inputs = []
-    with torchax.default_env():
-        for input in inputs:
-            # without `torch_view`, you get an error `AttributeError: 'jaxlib._jax.ArrayImpl' object has no attribute 'apply_jax_'`
-            # without `t2j`, you get an error `AttributeError: 'Tensor' object has no attribute 'apply_jax_'`
-            jax_input = torch_view(t2j(input))
-            jax_input.apply_jax_(jax.device_put,
-                                 NamedSharding(mesh, P(None, None)))
-            jax_inputs.append(jax_input)
     with torchax.default_env():
-        lora_result = lora_linear(torch.cat(jax_inputs))[0]
+        torchax_inputs = _shard_and_move_inputs_to_tpu(inputs, mesh)
+        actual_result = lora_linear(torchax_inputs)[0]
 
     expected_results: list[torch.Tensor] = []
     for input_, lora_id in zip(inputs, prompt_mapping):
@@ -303,19 +273,19 @@ def test_column_parallel_packed(dist_init, num_loras, repeats, stage) -> None:
         expected_results.append(result)
     expected_result = torch.cat(expected_results)
 
-    rtol, atol = TOLERANCES[lora_result.dtype]
+    rtol, atol = TOLERANCES[actual_result.dtype]
     with torchax.default_env():
-        lora_result_cpu = lora_result.to('cpu')
-        torch.testing.assert_close(lora_result_cpu,
+        actual_result_cpu = actual_result.to('cpu')
+        torch.testing.assert_close(actual_result_cpu,
                                    expected_result,
                                    rtol=rtol,
                                    atol=atol)
-    print(
-        f'Output max diff: {torch.max(torch.abs(expected_result - lora_result_cpu))}'
-    )
-    print(
-        f'Output mean diff: {torch.mean(torch.abs(expected_result - lora_result_cpu))}'
-    )
+    # print(
+    #     f'Output max diff: {torch.max(torch.abs(expected_result - actual_result_cpu))}'
+    # )
+    # print(
+    #     f'Output mean diff: {torch.mean(torch.abs(expected_result - actual_result_cpu))}'
+    # )
 
     # Check that resetting the lora weights succeeds
     # Here we set all lora weight to be empty.
@@ -329,41 +299,75 @@ def test_column_parallel_packed(dist_init, num_loras, repeats, stage) -> None:
         input_range=(0, 1),
         input_type=torch.float16,
         device='cpu')
-    lora_mapping = LoRAMapping(index_mapping, prompt_mapping, is_prefill=stage)
 
-    with torchax.default_env():
-        punica_wrapper.update_metadata(
-            lora_mapping,
-            index_to_id,
-            max_loras,
-            512,
-            lora_config.lora_extra_vocab_size,
-        )
+    _update_punica_wrapper_metadata(punica_wrapper, index_mapping,
+                                    prompt_mapping, stage, index_to_id,
+                                    lora_config)
 
-    jax_inputs = []
-    with torchax.default_env():
-        for input in inputs:
-            jax_input = torch_view(t2j(input))
-            jax_input.apply_jax_(jax.device_put,
-                                 NamedSharding(mesh, P(None, None)))
-            jax_inputs.append(jax_input)
     with torchax.default_env():
-        lora_result = lora_linear(torch.cat(jax_inputs))[0]
+        torchax_inputs = _shard_and_move_inputs_to_tpu(inputs, mesh)
+        actual_result = lora_linear(torchax_inputs)[0]
         expected_result = linear(torch.cat(inputs))[0]
 
-    rtol, atol = TOLERANCES[lora_result.dtype]
+    rtol, atol = TOLERANCES[actual_result.dtype]
     with torchax.default_env():
-        lora_result_cpu = lora_result.to('cpu')
-        torch.testing.assert_close(lora_result_cpu,
+        actual_result_cpu = actual_result.to('cpu')
+        torch.testing.assert_close(actual_result_cpu,
                                    expected_result,
                                    rtol=rtol,
                                    atol=atol)
-    print(
-        f'Output max diff: {torch.max(torch.abs(expected_result - lora_result_cpu))}'
-    )
-    print(
-        f'Output mean diff: {torch.mean(torch.abs(expected_result - lora_result_cpu))}'
+
+
+def _create_mesh():
+    axis_names = ("data", "model")
+    devices = jax.devices()
+    mesh_shape = (1, len(devices))
+    mesh = jax.make_mesh(mesh_shape, axis_names, devices=devices)
+    return mesh
+
+
+def _verify_lora_linear_layer(linear, lora_linear):
+    with torchax.default_env():
+        # lora_linear.weight has type torchax.tensor.Tensor
+        # BaseLinearLayerWithLoRA.weight property guarantees this.
+        # if len(devices) != 1, `reorder_concatenated_tensor_for_sharding` function may reorder the out_features dimension of the weight matrix.
+        # So the below check will fail.
+        if len(jax.devices()) == 1:
+            assert torch.equal(linear.weight.data,
+                               lora_linear.weight.to('cpu'))
+
+
+def _shard_and_move_inputs_to_tpu(inputs, mesh):
+    processed_inputs = []
+    for input in inputs:
+        # without `torch_view`, you get an error `AttributeError: 'jaxlib._jax.ArrayImpl' object has no attribute 'apply_jax_'`
+        # without `t2j`, you get an error `AttributeError: 'Tensor' object has no attribute 'apply_jax_'`
+        jax_input = torch_view(t2j(input))
+        jax_input.apply_jax_(jax.device_put,
+                             NamedSharding(mesh, P(None, None)))
+        processed_inputs.append(jax_input)
+    return torch.cat(processed_inputs)
+
+
+def _update_punica_wrapper_metadata(punica_wrapper, index_mapping,
+                                    prompt_mapping, stage, index_to_id,
+                                    lora_config):
+    lora_mapping = LoRAMapping(index_mapping, prompt_mapping, is_prefill=stage)
+    with torchax.default_env():
+        # Here we move the metadata from cpu to tpu.
+        punica_wrapper.update_metadata(
+            lora_mapping,
+            index_to_id,
+            lora_config.max_loras,
+            vocab_size=512,
+            extra_vocab_size=lora_config.lora_extra_vocab_size,
         )
+        assert jax_view(punica_wrapper._lora_indices_per_batch).platform(
+        ) == 'tpu', 'punica_wrapper._lora_indices_per_batch should have been moved to TPU.'
+        assert isinstance(
+            jax_view(punica_wrapper._lora_indices_per_batch).sharding,
+            jax.sharding.SingleDeviceSharding
+        ), 'punica_wrapper._lora_indices_per_batch should have been moved to TPU.'
 
 
 def _create_column_parallel_packed_layer(repeats, vllm_config, mesh):
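
For readers unfamiliar with the mesh and sharding pattern that the new _create_mesh and _shard_and_move_inputs_to_tpu helpers wrap, here is a minimal, self-contained sketch of the same placement in plain JAX. It is illustrative only and not part of this commit; the array shape and dtype are arbitrary placeholders, and plain JAX arrays stand in for the torchax tensors used in the test.

import jax
import jax.numpy as jnp
from jax.sharding import NamedSharding, PartitionSpec as P

# Build a (1, num_devices) mesh over the ("data", "model") axes, as _create_mesh does.
devices = jax.devices()
mesh = jax.make_mesh((1, len(devices)), ("data", "model"), devices=devices)

# Replicate an input across the mesh with P(None, None), mirroring the
# NamedSharding(mesh, P(None, None)) placement applied to the LoRA layer inputs.
x = jnp.ones((8, 64), dtype=jnp.float16)  # placeholder shape/dtype
x_sharded = jax.device_put(x, NamedSharding(mesh, P(None, None)))
print(x_sharded.sharding)

The test operates on torchax tensors rather than raw JAX arrays, which is why _shard_and_move_inputs_to_tpu first converts each torch input with t2j, wraps it with torch_view, and then applies the same jax.device_put placement via apply_jax_.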

0 commit comments