@@ -2,6 +2,7 @@
 from copy import deepcopy
 from dataclasses import dataclass
 from typing import Dict, List, Optional, Tuple
+from unittest.mock import patch
 
 import pytest
 import torch
@@ -32,7 +33,7 @@
 from vllm.model_executor.layers.logits_processor import LogitsProcessor
 from vllm.model_executor.layers.rotary_embedding import get_rope
 from vllm.model_executor.layers.vocab_parallel_embedding import (
-    ParallelLMHead, VocabParallelEmbedding)
+    ParallelLMHead, VocabParallelEmbedding, get_masked_input_and_mask)
 from vllm.model_executor.utils import set_random_seed
 
 from .utils import DummyLoRAManager
@@ -427,7 +428,8 @@ def _pretest():
     logits_processor = LogitsProcessor(
         vocab_size + lora_config.lora_extra_vocab_size, vocab_size)
     lora_logits_processor = LogitsProcessorWithLoRA(
-        logits_processor, 1024, linear.weight.dtype, linear.weight.device)
+        logits_processor, 1024, linear.weight.dtype, linear.weight.device,
+        None)
     lora_logits_processor.create_lora_weights(max_loras, lora_config)
 
     return linear, logits_processor, lora_logits_processor
@@ -867,3 +869,216 @@ def test_rotary_embedding_long_context(dist_init, num_loras, device,
 
     torch.allclose(ref_q, actual_q)
     torch.allclose(ref_k, actual_k)
+
+
+@pytest.mark.parametrize("tp_size", [1, 2, 4, 8])
+@pytest.mark.parametrize("seed", list(range(256)))
+def test_vocab_parallel_embedding_indices(tp_size, seed):
+    random.seed(seed)
+    vocab_size = random.randint(4000, 64000)
+    added_vocab_size = random.randint(0, 1024)
+    org_vocab_size = vocab_size - added_vocab_size
+    last_org_vocab_end_index = 0
+    last_added_vocab_end_index = org_vocab_size
+    computed_vocab_size = 0
+    computed_org_vocab_size = 0
+    computed_added_vocab_size = 0
+    vocab_size_padded = -1
+
+    all_org_tokens = []
+    all_added_tokens = []
+    token_ids = []
+
+    for tp_rank in range(tp_size):
+        with patch(
+                "vllm.model_executor.layers.vocab_parallel_embedding.get_tensor_model_parallel_rank",
+                return_value=tp_rank
+        ), patch(
+                "vllm.model_executor.layers.vocab_parallel_embedding.get_tensor_model_parallel_world_size",
+                return_value=tp_size):
+            vocab_embedding = VocabParallelEmbedding(
+                vocab_size, 1, org_num_embeddings=org_vocab_size)
+        vocab_size_padded = vocab_embedding.num_embeddings_padded
+        shard_indices = vocab_embedding.shard_indices
+        # Assert that the ranges are contiguous
+        assert shard_indices.org_vocab_start_index == last_org_vocab_end_index
+        assert (shard_indices.added_vocab_start_index ==
+                last_added_vocab_end_index)
+
+        # Ensure that we are not exceeding the vocab size
+        computed_vocab_size += shard_indices.num_elements_padded
+        computed_org_vocab_size += shard_indices.num_org_elements
+        computed_added_vocab_size += shard_indices.num_added_elements
+
+        # Ensure that the ranges are not overlapping
+        all_org_tokens.extend(
+            range(shard_indices.org_vocab_start_index,
+                  shard_indices.org_vocab_end_index))
+        all_added_tokens.extend(
+            range(shard_indices.added_vocab_start_index,
+                  shard_indices.added_vocab_end_index))
+
+        token_ids.extend(
+            range(shard_indices.org_vocab_start_index,
+                  shard_indices.org_vocab_end_index))
+        token_ids.extend([-1] * (shard_indices.num_org_elements_padded -
+                                 shard_indices.num_org_elements))
+        token_ids.extend(
+            range(shard_indices.added_vocab_start_index,
+                  shard_indices.added_vocab_end_index))
+        token_ids.extend([-1] * (shard_indices.num_added_elements_padded -
+                                 shard_indices.num_added_elements))
+
+        last_org_vocab_end_index = shard_indices.org_vocab_end_index
+        last_added_vocab_end_index = shard_indices.added_vocab_end_index
+
+    assert computed_vocab_size == vocab_size_padded
+    assert computed_org_vocab_size == org_vocab_size
+    assert computed_added_vocab_size == added_vocab_size
+
+    # Ensure that the ranges are not overlapping
+    assert len(all_org_tokens) == len(set(all_org_tokens))
+    assert len(all_added_tokens) == len(set(all_added_tokens))
+    assert not set(all_org_tokens).intersection(set(all_added_tokens))
+
+    token_ids_tensor = torch.tensor(token_ids, dtype=torch.long)
+    reindex_mapping = vocab_embedding.get_sharded_to_full_mapping()
+    assert reindex_mapping is not None or tp_size == 1
+    if reindex_mapping is not None:
+        reindexed_token_ids = token_ids_tensor[reindex_mapping]
+        expected = torch.tensor(list(range(0, vocab_size)))
+        assert reindexed_token_ids[:vocab_size].equal(expected)
+        assert torch.all(reindexed_token_ids[vocab_size:] == -1)
+
+
+def test_get_masked_input_and_mask():
+    x = torch.tensor([0, 1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11])
+
+    # base tp 1 case, no padding
+    modified_x, _ = get_masked_input_and_mask(x,
+                                              org_vocab_start_index=0,
+                                              org_vocab_end_index=8,
+                                              added_vocab_start_index=8,
+                                              added_vocab_end_index=12,
+                                              num_org_vocab_padding=0)
+    assert torch.equal(x, modified_x)
+
+    # tp 2 case, no padding
+    modified_x_rank_0, _ = get_masked_input_and_mask(x,
+                                                     org_vocab_start_index=0,
+                                                     org_vocab_end_index=4,
+                                                     added_vocab_start_index=8,
+                                                     added_vocab_end_index=10,
+                                                     num_org_vocab_padding=0)
+    modified_x_rank_1, _ = get_masked_input_and_mask(
+        x,
+        org_vocab_start_index=4,
+        org_vocab_end_index=8,
+        added_vocab_start_index=10,
+        added_vocab_end_index=12,
+        num_org_vocab_padding=0)
+    assert torch.equal(modified_x_rank_0,
+                       torch.tensor([0, 1, 2, 3, 0, 0, 0, 0, 4, 5, 0, 0]))
+    assert torch.equal(modified_x_rank_1,
+                       torch.tensor([0, 0, 0, 0, 0, 1, 2, 3, 0, 0, 4, 5]))
+
+    # tp 4 case, no padding
+    modified_x_rank_0, _ = get_masked_input_and_mask(x,
+                                                     org_vocab_start_index=0,
+                                                     org_vocab_end_index=2,
+                                                     added_vocab_start_index=8,
+                                                     added_vocab_end_index=9,
+                                                     num_org_vocab_padding=0)
+    modified_x_rank_1, _ = get_masked_input_and_mask(x,
+                                                     org_vocab_start_index=2,
+                                                     org_vocab_end_index=4,
+                                                     added_vocab_start_index=9,
+                                                     added_vocab_end_index=10,
+                                                     num_org_vocab_padding=0)
+    modified_x_rank_2, _ = get_masked_input_and_mask(
+        x,
+        org_vocab_start_index=4,
+        org_vocab_end_index=6,
+        added_vocab_start_index=10,
+        added_vocab_end_index=11,
+        num_org_vocab_padding=0)
+    modified_x_rank_3, _ = get_masked_input_and_mask(
+        x,
+        org_vocab_start_index=6,
+        org_vocab_end_index=8,
+        added_vocab_start_index=11,
+        added_vocab_end_index=12,
+        num_org_vocab_padding=0)
+    assert torch.equal(modified_x_rank_0,
+                       torch.tensor([0, 1, 0, 0, 0, 0, 0, 0, 2, 0, 0, 0]))
+    assert torch.equal(modified_x_rank_1,
+                       torch.tensor([0, 0, 0, 1, 0, 0, 0, 0, 0, 2, 0, 0]))
+    assert torch.equal(modified_x_rank_2,
+                       torch.tensor([0, 0, 0, 0, 0, 1, 0, 0, 0, 0, 2, 0]))
+    assert torch.equal(modified_x_rank_3,
+                       torch.tensor([0, 0, 0, 0, 0, 0, 0, 1, 0, 0, 0, 2]))
+
+    # base tp 1 case, with padding
+    modified_x, _ = get_masked_input_and_mask(x,
+                                              org_vocab_start_index=0,
+                                              org_vocab_end_index=8,
+                                              added_vocab_start_index=8,
+                                              added_vocab_end_index=12,
+                                              num_org_vocab_padding=2)
+    assert torch.equal(modified_x,
+                       torch.tensor([0, 1, 2, 3, 4, 5, 6, 7, 10, 11, 12, 13]))
+
+    # tp 2 case, with padding
+    modified_x_rank_0, _ = get_masked_input_and_mask(x,
+                                                     org_vocab_start_index=0,
+                                                     org_vocab_end_index=4,
+                                                     added_vocab_start_index=8,
+                                                     added_vocab_end_index=10,
+                                                     num_org_vocab_padding=2)
+    modified_x_rank_1, _ = get_masked_input_and_mask(
+        x,
+        org_vocab_start_index=4,
+        org_vocab_end_index=8,
+        added_vocab_start_index=10,
+        added_vocab_end_index=12,
+        num_org_vocab_padding=2)
+    assert torch.equal(modified_x_rank_0,
+                       torch.tensor([0, 1, 2, 3, 0, 0, 0, 0, 6, 7, 0, 0]))
+    assert torch.equal(modified_x_rank_1,
+                       torch.tensor([0, 0, 0, 0, 0, 1, 2, 3, 0, 0, 6, 7]))
+
+    # tp 4 case, with padding
+    modified_x_rank_0, _ = get_masked_input_and_mask(x,
+                                                     org_vocab_start_index=0,
+                                                     org_vocab_end_index=2,
+                                                     added_vocab_start_index=8,
+                                                     added_vocab_end_index=9,
+                                                     num_org_vocab_padding=2)
+    modified_x_rank_1, _ = get_masked_input_and_mask(x,
+                                                     org_vocab_start_index=2,
+                                                     org_vocab_end_index=4,
+                                                     added_vocab_start_index=9,
+                                                     added_vocab_end_index=10,
+                                                     num_org_vocab_padding=2)
+    modified_x_rank_2, _ = get_masked_input_and_mask(
+        x,
+        org_vocab_start_index=4,
+        org_vocab_end_index=6,
+        added_vocab_start_index=10,
+        added_vocab_end_index=11,
+        num_org_vocab_padding=2)
+    modified_x_rank_3, _ = get_masked_input_and_mask(
+        x,
+        org_vocab_start_index=6,
+        org_vocab_end_index=8,
+        added_vocab_start_index=11,
+        added_vocab_end_index=12,
+        num_org_vocab_padding=2)
+    assert torch.equal(modified_x_rank_0,
+                       torch.tensor([0, 1, 0, 0, 0, 0, 0, 0, 4, 0, 0, 0]))
+    assert torch.equal(modified_x_rank_1,
+                       torch.tensor([0, 0, 0, 1, 0, 0, 0, 0, 0, 4, 0, 0]))
+    assert torch.equal(modified_x_rank_2,
+                       torch.tensor([0, 0, 0, 0, 0, 1, 0, 0, 0, 0, 4, 0]))
+    assert torch.equal(modified_x_rank_3,
+                       torch.tensor([0, 0, 0, 0, 0, 0, 0, 1, 0, 0, 0, 4]))
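
Note on semantics (editor's sketch): the assertions in test_get_masked_input_and_mask above fully determine what the helper must compute on these inputs. The following minimal reference is inferred from those assertions; it is not vLLM's actual implementation, the name reference_masked_input is hypothetical, and the polarity of the returned mask is an assumption (the tests discard it).

import torch

def reference_masked_input(x, org_vocab_start_index, org_vocab_end_index,
                           added_vocab_start_index, added_vocab_end_index,
                           num_org_vocab_padding):
    # Tokens in this rank's slice of the original vocab...
    org_mask = (x >= org_vocab_start_index) & (x < org_vocab_end_index)
    # ...and in its slice of the added (extra/LoRA) vocab.
    added_mask = (x >= added_vocab_start_index) & (x < added_vocab_end_index)
    num_org = org_vocab_end_index - org_vocab_start_index
    # Original tokens shift to the front of the shard; added tokens land
    # right after the (optionally padded) original block.
    mapped = torch.where(org_mask, x - org_vocab_start_index, x)
    mapped = torch.where(
        added_mask,
        x - added_vocab_start_index + num_org + num_org_vocab_padding,
        mapped)
    # Tokens owned by other ranks map to index 0 and are masked out.
    valid = org_mask | added_mask
    return mapped * valid, ~valid

For example, on the "tp 2 case, no padding" rank-0 inputs above (org slice [0, 4), added slice [8, 10)), this maps [0..11] to [0, 1, 2, 3, 0, 0, 0, 0, 4, 5, 0, 0], matching the expected tensor.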
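Similarly, test_vocab_parallel_embedding_indices checks that each rank's shard is laid out as [original slice | padding | added slice | padding] and that get_sharded_to_full_mapping() reorders the concatenation of all shards back into [0, vocab_size) followed by padding slots. A hand-sized illustration of that round trip (the shard and padding sizes here are made up for illustration, not vLLM defaults):

import torch

# Assumed tp=2 layout: 6 original + 2 added tokens, -1 marks padding.
sharded = torch.tensor([0, 1, 2, -1, 6, -1,   # rank 0: org [0, 3) | pad | added [6, 7) | pad
                        3, 4, 5, -1, 7, -1])  # rank 1: org [3, 6) | pad | added [7, 8) | pad
mapping = torch.tensor([0, 1, 2, 6, 7, 8, 4, 10, 3, 5, 9, 11])
assert sharded[mapping][:8].equal(torch.arange(8))  # full vocab, in order
assert (sharded[mapping][8:] == -1).all()           # padding gathered at the end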