diff --git a/aviary/backend/llm/initializers/hf_transformers/base.py b/aviary/backend/llm/initializers/hf_transformers/base.py
index 8fa02128..d850aac7 100644
--- a/aviary/backend/llm/initializers/hf_transformers/base.py
+++ b/aviary/backend/llm/initializers/hf_transformers/base.py
@@ -88,22 +88,22 @@ def load_model(self, model_id: str) -> "PreTrainedModel":
         Args:
             model_id (str): Hugging Face model ID.
         """
-        model_id_or_path = self._get_model_location_on_disk(model_id)
         from_pretrained_kwargs = self._get_model_from_pretrained_kwargs()
 
-        logger.info(f"Loading model {model_id_or_path}...")
+        logger.info(f"Loading model {model_id}...")
         try:
             model = AutoModelForCausalLM.from_pretrained(
-                model_id_or_path, **from_pretrained_kwargs
+                model_id, **from_pretrained_kwargs
             )
         except OSError:
-            if model_id_or_path != model_id:
+            location = self._get_model_location_on_disk(model_id)
+            if model_id != location:
                 logger.warning(
-                    f"Couldn't load model from derived path {model_id_or_path}, "
-                    f"trying to load from model_id {model_id}"
+                    f"Couldn't load model {model_id}, "
+                    f"trying to load from derived location {location}"
                 )
                 model = AutoModelForCausalLM.from_pretrained(
-                    model_id, **from_pretrained_kwargs
+                    location, **from_pretrained_kwargs
                 )
             else:
                 raise
@@ -116,30 +116,26 @@ def load_tokenizer(self, tokenizer_id: str) -> "PreTrainedTokenizer":
         Args:
             tokenizer_id (str): Hugging Face tokenizer name.
         """
-        tokenizer_id_or_path = self._get_model_location_on_disk(tokenizer_id)
         from_pretrained_kwargs = self._get_model_from_pretrained_kwargs()
+        trust_remote_code = from_pretrained_kwargs.get("trust_remote_code", False)
 
         # TODO make this more robust
         try:
             return AutoTokenizer.from_pretrained(
-                tokenizer_id_or_path,
-                padding_side="left",
-                trust_remote_code=from_pretrained_kwargs.get(
-                    "trust_remote_code", False
-                ),
+                tokenizer_id, padding_side="left", trust_remote_code=trust_remote_code
             )
         except Exception:
-            logger.warning(
-                f"Couldn't load tokenizer from derived path {tokenizer_id_or_path}, "
-                f"trying to load from model_id {tokenizer_id}"
-            )
-            return AutoTokenizer.from_pretrained(
-                tokenizer_id,
-                padding_side="left",
-                trust_remote_code=from_pretrained_kwargs.get(
-                    "trust_remote_code", False
-                ),
-            )
+            location = self._get_model_location_on_disk(tokenizer_id)
+            if tokenizer_id != location:
+                logger.warning(
+                    f"Couldn't load tokenizer {tokenizer_id}, "
+                    f"trying to load from derived location {location}"
+                )
+                return AutoTokenizer.from_pretrained(
+                    location, padding_side="left", trust_remote_code=trust_remote_code
+                )
+            else:
+                raise
 
     def postprocess_model(self, model: "PreTrainedModel") -> "PreTrainedModel":
         """Postprocess model.
diff --git a/aviary/backend/llm/initializers/hf_transformers/deepspeed.py b/aviary/backend/llm/initializers/hf_transformers/deepspeed.py
index e881fc73..e75b4ebd 100644
--- a/aviary/backend/llm/initializers/hf_transformers/deepspeed.py
+++ b/aviary/backend/llm/initializers/hf_transformers/deepspeed.py
@@ -49,6 +49,9 @@ def __init__(
         ds_inference_kwargs: Optional[Dict[str, Any]] = None,
         **from_pretrained_kwargs,
     ):
+        if dtype not in (torch.float16, torch.float32, torch.int8):
+            raise ValueError(f"dtype {dtype} is not supported by DeepSpeed.")
+
         super().__init__(
             device=device,
             world_size=world_size,
@@ -118,25 +121,23 @@ def _generate_checkpoint_json(
         return os.path.abspath(repo_root), os.path.abspath(checkpoints_json)
 
     def load_model(self, model_id: str) -> "PreTrainedModel":
-        model_id_or_path = self._get_model_location_on_disk(model_id)
         from_pretrained_kwargs = self._get_model_from_pretrained_kwargs()
 
-        logger.info(f"Loading model {model_id_or_path}...")
+        logger.info(f"Loading model {model_id}...")
         if self.use_meta_tensor:
             logger.info("Loading model using DeepSpeed meta tensor...")
 
             try:
-                config = AutoConfig.from_pretrained(
-                    model_id_or_path, **from_pretrained_kwargs
-                )
+                config = AutoConfig.from_pretrained(model_id, **from_pretrained_kwargs)
             except OSError:
-                if model_id_or_path != model_id:
+                location = self._get_model_location_on_disk(model_id)
+                if model_id != location:
                     logger.warning(
-                        f"Couldn't load model from derived path {model_id_or_path}, "
-                        f"trying to load from model_id {model_id}"
+                        f"Couldn't load model {model_id}, "
+                        f"trying to load from derived location {location}"
                     )
                     config = AutoConfig.from_pretrained(
-                        model_id, **from_pretrained_kwargs
+                        location, **from_pretrained_kwargs
                     )
                 else:
                     raise
@@ -144,25 +145,14 @@ def load_model(self, model_id: str) -> "PreTrainedModel":
             self._repo_root, self._checkpoints_json = self._generate_checkpoint_json(
                 model_id
             )
+            trust_remote_code = from_pretrained_kwargs.get("trust_remote_code", False)
 
             with deepspeed.OnDevice(dtype=torch.float16, device="meta"):
-                model = AutoModelForCausalLM.from_config(config)
-        else:
-            try:
-                model = AutoModelForCausalLM.from_pretrained(
-                    model_id_or_path, **from_pretrained_kwargs
+                model = AutoModelForCausalLM.from_config(
+                    config, trust_remote_code=trust_remote_code
                 )
-            except OSError:
-                if model_id_or_path != model_id:
-                    logger.warning(
-                        f"Couldn't load model from derived path {model_id_or_path}, "
-                        f"trying to load from model_id {model_id}"
-                    )
-                    model = AutoModelForCausalLM.from_pretrained(
-                        model_id, **from_pretrained_kwargs
-                    )
-                else:
-                    raise
+        else:
+            model = super().load_model(model_id)
 
         model.eval()
         return model
@@ -220,5 +210,5 @@ def postprocess_model(self, model: "PreTrainedModel") -> "PreTrainedModel":
         # Add attributes for compatibility with the pipeline
         model.use_kernel = self.use_kernel
         model.device = self.device
-        model = model.to(self.device)
+        # model = model.to(self.device)
         return model
diff --git a/aviary/backend/llm/predictor.py b/aviary/backend/llm/predictor.py
index 1d2552ab..40c21dec 100644
--- a/aviary/backend/llm/predictor.py
+++ b/aviary/backend/llm/predictor.py
@@ -36,6 +36,7 @@ def init_model(
     world_size: int,
     local_rank: int,
     max_batch_size: Optional[int] = None,
+    full_warmup: bool = False,
 ):
     """Initialize the model.
 
@@ -44,6 +45,8 @@ def init_model(
         world_size (int): Number of GPUs.
        local_rank (int): Local rank of the current GPU.
        max_batch_size (Optional[int], optional): Maximum batch size. Defaults to None.
+        full_warmup (bool, optional): Whether to do a full warmup (min_new_tokens=max_new_tokens and
+            max out input length). Defaults to False.
     """
     logger.info(f"Initializing model {llm_config.model_id}...")
 
@@ -75,8 +78,9 @@ def init_model(
     # otherwise subsequent batches with more entries than the first batch
     # will raise CUDA errors if use_kernel=True.
     batch_size = max_batch_size or 1
-    prompt = [WARMUP_PROMPT] * (
-        int(llm_config.max_input_words / (len(WARMUP_PROMPT.split()) + 1)) + 1
+    n_repeats = llm_config.max_input_words if full_warmup else 1
+    prompt = [WARMUP_PROMPT] * max(
+        1, (int(n_repeats / (len(WARMUP_PROMPT.split()) + 1)))
     )
     prompt = " ".join(prompt)
     logger.info(
@@ -84,7 +88,12 @@ def init_model(
     )
     generate_kwargs = llm_config.generation.all_generate_kwargs.copy()
     if "max_new_tokens" in generate_kwargs:
-        generate_kwargs["min_new_tokens"] = generate_kwargs["max_new_tokens"]
+        if full_warmup:
+            generate_kwargs["min_new_tokens"] = generate_kwargs["max_new_tokens"]
+        else:
+            generate_kwargs["max_new_tokens"] = generate_kwargs.get(
+                "min_new_tokens", 16
+            )
     warmup_success = False
     while not warmup_success:
         try:
@@ -178,6 +187,7 @@ def init_model(
             self.world_size,
             local_rank,
             max_batch_size=self.llm_config.generation.max_batch_size,
+            full_warmup=self.llm_config.initialization.full_warmup,
         )
 
     def generate(
diff --git a/aviary/backend/llm/utils.py b/aviary/backend/llm/utils.py
index 31f8e96d..b84bce4a 100644
--- a/aviary/backend/llm/utils.py
+++ b/aviary/backend/llm/utils.py
@@ -38,7 +38,6 @@ def download_model(
     from transformers.utils.hub import TRANSFORMERS_CACHE
 
     s3_sync_args = s3_sync_args or []
-    logger.info(f"Downloading {model_id} from {bucket_uri} to '{TRANSFORMERS_CACHE}'")
     path = os.path.expanduser(
         os.path.join(TRANSFORMERS_CACHE, f"models--{model_id.replace('/', '--')}")
     )
@@ -53,6 +52,9 @@ def download_model(
     )
     with open(os.path.join(".", "hash"), "r") as f:
         f_hash = f.read().strip()
+    logger.info(
+        f"Downloading {model_id} from {bucket_uri} to {os.path.join(path, 'snapshots', f_hash)}"
+    )
     subprocess.run(["mkdir", "-p", os.path.join(path, "snapshots", f_hash)])
     subprocess.run(["mkdir", "-p", os.path.join(path, "refs")])
     subprocess.run(
diff --git a/aviary/backend/server/models.py b/aviary/backend/server/models.py
index a484e8fa..74890553 100644
--- a/aviary/backend/server/models.py
+++ b/aviary/backend/server/models.py
@@ -208,7 +208,9 @@ def allowed_pipelines(self) -> Set[str]:
 class Transformers(Initializer, extra=Extra.forbid):
     use_bettertransformer: bool = False
     torch_compile: Optional[TorchCompile] = None
-    dtype: str = "float16"
+    dtype: Union[
+        Literal["float16"], Literal["bfloat16"], Literal["float32"], Literal["int8"]
+    ] = "float16"
     from_pretrained_kwargs: Dict[str, Any] = {}
 
     @property
@@ -229,6 +231,7 @@ def allowed_pipelines(self) -> Set[str]:
 
 class DeepSpeed(Transformers):
     type: Literal["DeepSpeed"]
+    dtype: Union[Literal["float16"], Literal["float32"], Literal["int8"]] = "float16"
     use_kernel: bool = False
     max_tokens: int = 1024
     use_meta_tensor: bool = False
@@ -289,6 +292,7 @@ class InitializationConfig(BaseModelExtended):
     s3_mirror_config: Optional[S3MirrorConfig] = None
     runtime_env: Optional[Dict[str, Any]] = None
     hf_model_id: Optional[str] = None
+    full_warmup: bool = False  # For debugging purposes
 
     @root_validator
     def initializer_pipeline(cls, values):
diff --git a/deploy/_internal/backend/cluster-env.yaml b/deploy/_internal/backend/cluster-env.yaml
index 85fa9db1..af731813 100644
--- a/deploy/_internal/backend/cluster-env.yaml
+++ b/deploy/_internal/backend/cluster-env.yaml
@@ -9,7 +9,7 @@ env_vars:
   TORCH_HOME: /mnt/local_storage/data/cache/torch
 post_build_cmds:
 - |-
-  echo "dedup version 1. increment this to force a rebuild."
+  echo "dedup version 2. increment this to force a rebuild."
   pip uninstall -y torch-scatter torch-sparse torch-cluster torch-spline-conv torch-geometric tensorflow
   pip install -i https://download.pytorch.org/whl/cu118 torch torchvision torchaudio
   pip install \
diff --git a/models/OpenAssistant--falcon-40b-sft-top1-560.yaml b/models/OpenAssistant--falcon-40b-sft-top1-560.yaml
new file mode 100644
index 00000000..9be04488
--- /dev/null
+++ b/models/OpenAssistant--falcon-40b-sft-top1-560.yaml
@@ -0,0 +1,46 @@
+deployment_config:
+  autoscaling_config:
+    min_replicas: 2
+    initial_replicas: 2
+    max_replicas: 2
+    target_num_ongoing_requests_per_replica: 1.0
+    metrics_interval_s: 10.0
+    look_back_period_s: 30.0
+    smoothing_factor: 1.0
+    downscale_delay_s: 300.0
+    upscale_delay_s: 90.0
+  ray_actor_options:
+    resources:
+      accelerator_type_cpu: 0.01
+model_config:
+  model_id: OpenAssistant/falcon-40b-sft-top1-560
+  max_input_words: 800
+  initialization:
+    s3_mirror_config:
+      bucket_uri: s3://large-dl-models-mirror/models--OpenAssistant--falcon-40b-sft-top1-560/main-safetensors/
+    initializer:
+      type: DeepSpeed
+      dtype: float16
+      from_pretrained_kwargs:
+        trust_remote_code: true
+      use_kernel: true
+      max_tokens: 1536
+    pipeline: default
+  generation:
+    max_batch_size: 8
+    generate_kwargs:
+      do_sample: true
+      max_new_tokens: 512
+      min_new_tokens: 16
+      temperature: 0.4
+      top_p: 0.9
+      repetition_penalty: 1.02
+      return_token_type_ids: false
+    prompt_format: "<|prefix_begin|>Below are a series of dialogues between various people and an AI assistant. The AI tries to be helpful, polite, honest, sophisticated, emotionally aware, and humble-but-knowledgeable. The assistant is happy to help with almost anything, and will do its best to understand exactly what is needed. It also tries to avoid giving false or misleading information, and it caveats when it isn't entirely sure about the right answer. That said, the assistant is practical and really does its best, and doesn't let caution get too much in the way of being useful.<|prefix_end|><|prompter|>{instruction}<|endoftext|><|assistant|>"
+    stopping_sequences: ["<|prompter|>", "<|assistant|>", "<|endoftext|>"]
+scaling_config:
+  num_workers: 2
+  num_gpus_per_worker: 1
+  num_cpus_per_worker: 8
+  resources_per_worker:
+    accelerator_type_a100: 0.01
diff --git a/models/OpenAssistant--falcon-7b-sft-top1-696.yaml b/models/OpenAssistant--falcon-7b-sft-top1-696.yaml
new file mode 100644
index 00000000..10753ab9
--- /dev/null
+++ b/models/OpenAssistant--falcon-7b-sft-top1-696.yaml
@@ -0,0 +1,47 @@
+deployment_config:
+  autoscaling_config:
+    min_replicas: 1
+    initial_replicas: 1
+    max_replicas: 8
+    target_num_ongoing_requests_per_replica: 1.0
+    metrics_interval_s: 10.0
+    look_back_period_s: 30.0
+    smoothing_factor: 1.0
+    downscale_delay_s: 300.0
+    upscale_delay_s: 90.0
+  ray_actor_options:
+    resources:
+      accelerator_type_cpu: 0.01
+model_config:
+  model_id: OpenAssistant/falcon-7b-sft-top1-696
+  max_input_words: 800
+  initialization:
+    s3_mirror_config:
+      bucket_uri: s3://large-dl-models-mirror/models--OpenAssistant--falcon-7b-sft-top1-696/main-safetensors/
+    initializer:
+      type: DeviceMap
+      dtype: bfloat16
+      from_pretrained_kwargs:
+        trust_remote_code: true
+      torch_compile:
+        backend: inductor
+        mode: max-autotune
+    pipeline: default
+  generation:
+    max_batch_size: 4
+    generate_kwargs:
+      do_sample: true
+      max_new_tokens: 512
+      min_new_tokens: 16
+      temperature: 0.4
+      top_p: 0.9
+      repetition_penalty: 1.02
+      return_token_type_ids: false
+    prompt_format: "<|prefix_begin|>Below are a series of dialogues between various people and an AI assistant. The AI tries to be helpful, polite, honest, sophisticated, emotionally aware, and humble-but-knowledgeable. The assistant is happy to help with almost anything, and will do its best to understand exactly what is needed. It also tries to avoid giving false or misleading information, and it caveats when it isn't entirely sure about the right answer. That said, the assistant is practical and really does its best, and doesn't let caution get too much in the way of being useful.<|prefix_end|><|prompter|>{instruction}<|endoftext|><|assistant|>"
+    stopping_sequences: ["<|prompter|>", "<|assistant|>", "<|endoftext|>"]
+scaling_config:
+  num_workers: 1
+  num_gpus_per_worker: 1
+  num_cpus_per_worker: 8
+  resources_per_worker:
+    accelerator_type_a10: 0.01
diff --git a/models/OpenAssistant--oasst-sft-7-llama-30b-xor.yaml b/models/OpenAssistant--oasst-sft-7-llama-30b-xor.yaml
index 57868050..108ac7c6 100644
--- a/models/OpenAssistant--oasst-sft-7-llama-30b-xor.yaml
+++ b/models/OpenAssistant--oasst-sft-7-llama-30b-xor.yaml
@@ -1,8 +1,8 @@
 deployment_config:
   autoscaling_config:
-    min_replicas: 1
-    initial_replicas: 1
-    max_replicas: 4
+    min_replicas: 2
+    initial_replicas: 2
+    max_replicas: 2
     target_num_ongoing_requests_per_replica: 1.0
     metrics_interval_s: 10.0
     look_back_period_s: 30.0
@@ -32,7 +32,7 @@ model_config:
       max_tokens: 1536
     pipeline: default
   generation:
-    max_batch_size: 4
+    max_batch_size: 16
     generate_kwargs:
       do_sample: true
       max_new_tokens: 512
@@ -46,8 +46,8 @@ model_config:
     prompt_format: "<|prefix_begin|>Below are a series of dialogues between various people and an AI assistant. The AI tries to be helpful, polite, honest, sophisticated, emotionally aware, and humble-but-knowledgeable. The assistant is happy to help with almost anything, and will do its best to understand exactly what is needed. It also tries to avoid giving false or misleading information, and it caveats when it isn't entirely sure about the right answer. That said, the assistant is practical and really does its best, and doesn't let caution get too much in the way of being useful.<|prefix_end|><|prompter|>{instruction}<|assistant|>"
     stopping_sequences: [2, 32002, 32004]
 scaling_config:
-  num_workers: 4
+  num_workers: 2
   num_gpus_per_worker: 1
   num_cpus_per_worker: 4
   resources_per_worker:
-    accelerator_type_a10: 0.01
+    accelerator_type_a100: 0.01
diff --git a/models/mosaicml--mpt-7b-chat.yaml b/models/mosaicml--mpt-7b-chat.yaml
index 49b98341..bc55a2ae 100644
--- a/models/mosaicml--mpt-7b-chat.yaml
+++ b/models/mosaicml--mpt-7b-chat.yaml
@@ -1,7 +1,7 @@
 deployment_config:
   autoscaling_config:
-    min_replicas: 2
-    initial_replicas: 2
+    min_replicas: 1
+    initial_replicas: 1
     max_replicas: 8
     target_num_ongoing_requests_per_replica: 1.0
     metrics_interval_s: 10.0