[misc] improve cloudpickle registration and tests (#10202)
Signed-off-by: youkaichao <[email protected]>
youkaichao authored Nov 11, 2024
1 parent 20cf2f5 commit 73b9083
Showing 3 changed files with 50 additions and 31 deletions.
26 changes: 20 additions & 6 deletions tests/distributed/test_pipeline_parallel.py
@@ -32,6 +32,8 @@ class PPTestOptions(NamedTuple):
multi_node_only: bool
trust_remote_code: bool
tokenizer_mode: Optional[str]
load_format: Optional[str] = None
hf_overrides: Optional[str] = None


@dataclass
@@ -50,6 +52,8 @@ def detailed(
task: TaskOption = "auto",
trust_remote_code: bool = False,
tokenizer_mode: Optional[str] = None,
load_format: Optional[str] = None,
hf_overrides: Optional[str] = None,
):
return PPTestSettings(
parallel_setups=[
@@ -78,7 +82,9 @@ def detailed(
task=task,
test_options=PPTestOptions(multi_node_only=multi_node_only,
trust_remote_code=trust_remote_code,
tokenizer_mode=tokenizer_mode),
tokenizer_mode=tokenizer_mode,
load_format=load_format,
hf_overrides=hf_overrides),
)

@staticmethod
@@ -90,6 +96,8 @@ def fast(
multi_node_only: bool = False,
trust_remote_code: bool = False,
tokenizer_mode: Optional[str] = None,
load_format: Optional[str] = None,
hf_overrides: Optional[str] = None,
):
return PPTestSettings(
parallel_setups=[
@@ -102,7 +110,9 @@ def fast(
task=task,
test_options=PPTestOptions(multi_node_only=multi_node_only,
trust_remote_code=trust_remote_code,
tokenizer_mode=tokenizer_mode),
tokenizer_mode=tokenizer_mode,
load_format=load_format,
hf_overrides=hf_overrides),
)

def iter_params(self, model_name: str):
@@ -161,9 +171,8 @@ def iter_params(self, model_name: str):
"facebook/opt-iml-max-1.3b": PPTestSettings.fast(),
"OrionStarAI/Orion-14B-Chat": PPTestSettings.fast(trust_remote_code=True),
"microsoft/phi-2": PPTestSettings.fast(),
"microsoft/Phi-3-mini-4k-instruct": PPTestSettings.detailed(trust_remote_code=True, multi_node_only=True), # noqa: E501
"microsoft/Phi-3.5-MoE-instruct": PPTestSettings.detailed(trust_remote_code=True, multi_node_only=True, load_format="dummy", hf_overrides='{"num_hidden_layers": 4, "hidden_size": 512, "intermediate_size": 800, "num_attention_heads": 4, "num_key_value_heads": 1}'), # noqa: E501
"microsoft/Phi-3-small-8k-instruct": PPTestSettings.fast(trust_remote_code=True), # noqa: E501
"microsoft/Phi-3.5-MoE-instruct": PPTestSettings.fast(trust_remote_code=True), # noqa: E501
"adept/persimmon-8b-chat": PPTestSettings.fast(),
"Qwen/Qwen-7B-Chat": PPTestSettings.fast(trust_remote_code=True),
"Qwen/Qwen2-7B-Instruct": PPTestSettings.fast(),
@@ -214,9 +223,9 @@ def iter_params(self, model_name: str):
# NOTE: You can update this on your local machine to run specific tests
TEST_MODELS = [
# [LANGUAGE GENERATION]
"microsoft/Phi-3.5-MoE-instruct",
"meta-llama/Meta-Llama-3-8B",
"ibm/PowerLM-3b",
"microsoft/Phi-3-mini-4k-instruct",
# [LANGUAGE EMBEDDING]
"intfloat/e5-mistral-7b-instruct",
"BAAI/bge-multilingual-gemma2",
@@ -238,7 +247,8 @@ def _compare_tp(
method: Literal["generate", "encode"],
):
tp_size, pp_size, eager_mode, chunked_prefill = parallel_setup
multi_node_only, trust_remote_code, tokenizer_mode = test_options
multi_node_only, trust_remote_code, tokenizer_mode, \
load_format, hf_overrides = test_options

if num_gpus_available < tp_size * pp_size:
pytest.skip(f"Need at least {tp_size} x {pp_size} GPUs")
@@ -267,6 +277,10 @@ def _compare_tp(
common_args.append("--trust-remote-code")
if tokenizer_mode:
common_args.extend(["--tokenizer-mode", tokenizer_mode])
if load_format:
common_args.extend(["--load-format", load_format])
if hf_overrides:
common_args.extend(["--hf-overrides", hf_overrides])

if (distributed_backend == "ray" and tp_size == 2 and pp_size == 2
and chunked_prefill):
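The two new test options feed straight into the server arguments above (--load-format, --hf-overrides). As a rough illustration, the shrunken Phi-3.5-MoE entry amounts to booting the model with random weights and a 4-layer config. A minimal offline sketch, not part of this commit, assuming the LLM entry point also accepts load_format and hf_overrides as keyword arguments the way the CLI flags suggest:

# Hedged sketch: approximate offline equivalent of the new test entry for
# microsoft/Phi-3.5-MoE-instruct. Assumes LLM() forwards load_format and
# hf_overrides to the engine the same way --load-format / --hf-overrides do.
from vllm import LLM

llm = LLM(
    model="microsoft/Phi-3.5-MoE-instruct",
    trust_remote_code=True,
    load_format="dummy",  # random weights, no checkpoint download
    hf_overrides={        # shrink the architecture so it fits in CI
        "num_hidden_layers": 4,
        "hidden_size": 512,
        "intermediate_size": 800,
        "num_attention_heads": 4,
        "num_key_value_heads": 1,
    },
    tensor_parallel_size=1,
    pipeline_parallel_size=2,  # requires 2 GPUs for PP=2
)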
4 changes: 0 additions & 4 deletions vllm/engine/arg_utils.py
@@ -19,8 +19,6 @@
from vllm.model_executor.layers.pooler import PoolingType
from vllm.model_executor.layers.quantization import QUANTIZATION_METHODS
from vllm.platforms import current_platform
from vllm.transformers_utils.config import (
maybe_register_config_serialize_by_value)
from vllm.transformers_utils.utils import check_gguf_file
from vllm.utils import FlexibleArgumentParser, StoreBoolean

@@ -1013,8 +1011,6 @@ def create_engine_config(self) -> VllmConfig:
"supported for multimodal models and has been disabled.")
self.enable_prefix_caching = False

maybe_register_config_serialize_by_value(self.trust_remote_code)

cache_config = CacheConfig(
# neuron needs block_size = max_model_len
block_size=self.block_size if self.device != "neuron" else
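The explicit call is dropped here because, as the config.py change below shows, get_config now performs the registration itself whenever trust_remote_code is set. Roughly, under that assumption:

# Sketch of the new call order (not code from this commit): any path that loads
# a trust_remote_code config now triggers the cloudpickle registration, so
# create_engine_config() no longer needs to call it explicitly.
from vllm.transformers_utils.config import get_config

config = get_config("deepseek-ai/DeepSeek-V2.5", trust_remote_code=True)
# get_config() has already called maybe_register_config_serialize_by_value()
# before returning; no separate call in EngineArgs is required.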
51 changes: 30 additions & 21 deletions vllm/transformers_utils/config.py
@@ -234,6 +234,9 @@ def get_config(

patch_rope_scaling(config)

if trust_remote_code:
maybe_register_config_serialize_by_value()

return config


@@ -389,33 +392,39 @@ def get_sentence_transformer_tokenizer_config(model: str,
return None


def maybe_register_config_serialize_by_value(trust_remote_code: bool) -> None:
def maybe_register_config_serialize_by_value() -> None:
"""Try to register HF model configuration class to serialize by value
With trust_remote_code, the config class is typically an instance of a
custom class imported from the HF modules cache. The class will not be
importable in spawned workers by default (and won't exist at all on
other nodes), which breaks serialization of the config.
If trust_remote_code is set, and the model's config file specifies an
`AutoConfig` class, then the config class is typically an instance of
a custom class imported from the HF modules cache.
Examples:
>>> from transformers import AutoConfig
>>> klass = AutoConfig.from_pretrained('meta-llama/Meta-Llama-3-8B', trust_remote_code=True)
>>> klass.__class__ # transformers.models.llama.configuration_llama.LlamaConfig
>>> import transformers_modules # error, not initialized
>>> klass = AutoConfig.from_pretrained('deepseek-ai/DeepSeek-V2.5', trust_remote_code=True)
>>> import transformers_modules # success, initialized
>>> klass.__class__ # transformers_modules.deepseek-ai.DeepSeek-V2.5.98b11844770b2c3ffc18b175c758a803640f4e77.configuration_deepseek.DeepseekV2Config
In the DeepSeek example, the config class is an instance of a custom
class that is not serializable by default. This class will not be
importable in spawned workers, and won't exist at all on
other nodes, which breaks serialization of the config.
In this function we tell the cloudpickle serialization library to pass
instances of these generated classes by value instead of by reference,
i.e. the class definition is serialized along with its data so that the
class module does not need to be importable on the receiving end. This
registration only works if the modules cache has already been
initialized.
class module does not need to be importable on the receiving end.
See: https://github.com/cloudpipe/cloudpickle?tab=readme-ov-file#overriding-pickles-serialization-mechanism-for-importable-constructs
"""
if not trust_remote_code:
return

""" # noqa
try:
import transformers_modules
except ImportError:
logger.debug("Could not import transformers_modules used for remote"
" code. If remote code is not needed remove"
" `--trust-remote-code`.")
# the config does not need trust_remote_code
return

try:
@@ -428,19 +437,19 @@ class module does not need to be importable on the receiving end. This
ray.cloudpickle.register_pickle_by_value(transformers_modules)

# multiprocessing uses pickle to serialize arguments when using spawn
# Here we get pickle to use cloudpickle to serialize ModelConfig objects
# Here we get pickle to use cloudpickle to serialize config objects
# that contain instances of the custom config class to avoid
# serialization problems if the generated module (and model) has a `.`
# in its name
import multiprocessing
import pickle

from vllm.config import ModelConfig
from vllm.config import VllmConfig

def _reduce_modelconfig(mc: ModelConfig):
return (pickle.loads, (cloudpickle.dumps(mc), ))
def _reduce_config(config: VllmConfig):
return (pickle.loads, (cloudpickle.dumps(config), ))

multiprocessing.reducer.register(ModelConfig, _reduce_modelconfig)
multiprocessing.reducer.register(VllmConfig, _reduce_config)

except Exception as e:
logger.warning(
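The registration pattern above — reduce the object to (pickle.loads, (cloudpickle.dumps(obj),)) so its class definition travels by value — can be reproduced outside vLLM. A minimal, self-contained sketch, using a hypothetical DemoConfig in place of VllmConfig; multiprocessing.reducer.register and cloudpickle.dumps are the same real APIs the diff relies on:

import multiprocessing
import pickle

import cloudpickle


class DemoConfig:
    def __init__(self, num_hidden_layers: int):
        self.num_hidden_layers = num_hidden_layers


def _reduce_demo_config(config: DemoConfig):
    # The child process receives (callable, args) and calls
    # pickle.loads(cloudpickle.dumps(config)): the class definition is embedded
    # in the cloudpickle payload, so the child does not have to import it from
    # a named module.
    return (pickle.loads, (cloudpickle.dumps(config), ))


multiprocessing.reducer.register(DemoConfig, _reduce_demo_config)


def _count_layers(config: DemoConfig) -> int:
    return config.num_hidden_layers


if __name__ == "__main__":
    ctx = multiprocessing.get_context("spawn")
    with ctx.Pool(1) as pool:
        # The DemoConfig argument is serialized with the registered reducer.
        print(pool.apply(_count_layers, (DemoConfig(num_hidden_layers=4), )))

Registering a reducer per class, rather than relying on import-by-reference, also sidesteps the issue noted in the comment above: generated module names such as the transformers_modules path for DeepSeek-V2.5 contain dots, which breaks ordinary pickling of the config by module path.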
