This repository has been archived by the owner on May 28, 2024. It is now read-only.

Commit 4c6b7ce: Add Falcon (#113)

Yard1 authored Jun 7, 2023
Signed-off-by: Antoni Baum <[email protected]>
1 parent 9846bea
Showing 10 changed files with 159 additions and 64 deletions.
44 changes: 20 additions & 24 deletions aviary/backend/llm/initializers/hf_transformers/base.py
@@ -88,22 +88,22 @@ def load_model(self, model_id: str) -> "PreTrainedModel":
         Args:
             model_id (str): Hugging Face model ID.
         """
-        model_id_or_path = self._get_model_location_on_disk(model_id)
         from_pretrained_kwargs = self._get_model_from_pretrained_kwargs()
 
-        logger.info(f"Loading model {model_id_or_path}...")
+        logger.info(f"Loading model {model_id}...")
         try:
             model = AutoModelForCausalLM.from_pretrained(
-                model_id_or_path, **from_pretrained_kwargs
+                model_id, **from_pretrained_kwargs
             )
         except OSError:
-            if model_id_or_path != model_id:
+            location = self._get_model_location_on_disk(model_id)
+            if model_id != location:
                 logger.warning(
-                    f"Couldn't load model from derived path {model_id_or_path}, "
-                    f"trying to load from model_id {model_id}"
+                    f"Couldn't load model {model_id}, "
+                    f"trying to load from derived location {location}"
                 )
                 model = AutoModelForCausalLM.from_pretrained(
-                    model_id, **from_pretrained_kwargs
+                    location, **from_pretrained_kwargs
                 )
             else:
                 raise
@@ -116,30 +116,26 @@ def load_tokenizer(self, tokenizer_id: str) -> "PreTrainedTokenizer":
         Args:
             tokenizer_id (str): Hugging Face tokenizer name.
         """
-        tokenizer_id_or_path = self._get_model_location_on_disk(tokenizer_id)
         from_pretrained_kwargs = self._get_model_from_pretrained_kwargs()
+        trust_remote_code = from_pretrained_kwargs.get("trust_remote_code", False)
 
         # TODO make this more robust
         try:
             return AutoTokenizer.from_pretrained(
-                tokenizer_id_or_path,
-                padding_side="left",
-                trust_remote_code=from_pretrained_kwargs.get(
-                    "trust_remote_code", False
-                ),
+                tokenizer_id, padding_side="left", trust_remote_code=trust_remote_code
             )
         except Exception:
-            logger.warning(
-                f"Couldn't load tokenizer from derived path {tokenizer_id_or_path}, "
-                f"trying to load from model_id {tokenizer_id}"
-            )
-            return AutoTokenizer.from_pretrained(
-                tokenizer_id,
-                padding_side="left",
-                trust_remote_code=from_pretrained_kwargs.get(
-                    "trust_remote_code", False
-                ),
-            )
+            location = self._get_model_location_on_disk(tokenizer_id)
+            if tokenizer_id != location:
+                logger.warning(
+                    f"Couldn't load tokenizer {tokenizer_id}, "
+                    f"trying to load from derived location {location}"
+                )
+                return AutoTokenizer.from_pretrained(
+                    location, padding_side="left", trust_remote_code=trust_remote_code
+                )
+            else:
+                raise
 
     def postprocess_model(self, model: "PreTrainedModel") -> "PreTrainedModel":
         """Postprocess model.
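For context, a minimal sketch of the loading order these hunks establish: try the Hugging Face model ID first, and only fall back to the mirrored on-disk location when that fails. The helper name and `local_path` argument below are illustrative, not repository API.

    # Standalone sketch of the new fallback order (assumed names: load_with_fallback,
    # local_path; not part of the repository).
    from transformers import AutoModelForCausalLM

    def load_with_fallback(model_id: str, local_path: str, **from_pretrained_kwargs):
        try:
            # First attempt: the plain Hugging Face model ID (or a valid local directory).
            return AutoModelForCausalLM.from_pretrained(model_id, **from_pretrained_kwargs)
        except OSError:
            if local_path != model_id:
                # Fallback: the derived on-disk location (e.g. an S3 mirror synced
                # into the Transformers cache).
                return AutoModelForCausalLM.from_pretrained(local_path, **from_pretrained_kwargs)
            raise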
42 changes: 16 additions & 26 deletions aviary/backend/llm/initializers/hf_transformers/deepspeed.py
@@ -49,6 +49,9 @@ def __init__(
         ds_inference_kwargs: Optional[Dict[str, Any]] = None,
         **from_pretrained_kwargs,
     ):
+        if dtype not in (torch.float16, torch.float32, torch.int8):
+            raise ValueError(f"dtype {dtype} is not supported by DeepSpeed.")
+
         super().__init__(
             device=device,
             world_size=world_size,
@@ -118,51 +121,38 @@ def _generate_checkpoint_json(
         return os.path.abspath(repo_root), os.path.abspath(checkpoints_json)
 
     def load_model(self, model_id: str) -> "PreTrainedModel":
-        model_id_or_path = self._get_model_location_on_disk(model_id)
         from_pretrained_kwargs = self._get_model_from_pretrained_kwargs()
 
-        logger.info(f"Loading model {model_id_or_path}...")
+        logger.info(f"Loading model {model_id}...")
         if self.use_meta_tensor:
             logger.info("Loading model using DeepSpeed meta tensor...")
 
             try:
-                config = AutoConfig.from_pretrained(
-                    model_id_or_path, **from_pretrained_kwargs
-                )
+                config = AutoConfig.from_pretrained(model_id, **from_pretrained_kwargs)
             except OSError:
-                if model_id_or_path != model_id:
+                location = self._get_model_location_on_disk(model_id)
+                if model_id != location:
                     logger.warning(
-                        f"Couldn't load model from derived path {model_id_or_path}, "
-                        f"trying to load from model_id {model_id}"
+                        f"Couldn't load model {model_id}, "
+                        f"trying to load from derived location {location}"
                     )
                     config = AutoConfig.from_pretrained(
-                        model_id, **from_pretrained_kwargs
+                        location, **from_pretrained_kwargs
                     )
                 else:
                     raise
 
             self._repo_root, self._checkpoints_json = self._generate_checkpoint_json(
                 model_id
             )
+            trust_remote_code = from_pretrained_kwargs.get("trust_remote_code", False)
 
             with deepspeed.OnDevice(dtype=torch.float16, device="meta"):
-                model = AutoModelForCausalLM.from_config(config)
+                model = AutoModelForCausalLM.from_config(
+                    config, trust_remote_code=trust_remote_code
+                )
         else:
-            try:
-                model = AutoModelForCausalLM.from_pretrained(
-                    model_id_or_path, **from_pretrained_kwargs
-                )
-            except OSError:
-                if model_id_or_path != model_id:
-                    logger.warning(
-                        f"Couldn't load model from derived path {model_id_or_path}, "
-                        f"trying to load from model_id {model_id}"
-                    )
-                    model = AutoModelForCausalLM.from_pretrained(
-                        model_id, **from_pretrained_kwargs
-                    )
-                else:
-                    raise
+            model = super().load_model(model_id)
         model.eval()
         return model

@@ -220,5 +210,5 @@ def postprocess_model(self, model: "PreTrainedModel") -> "PreTrainedModel":
         # Add attributes for compatibility with the pipeline
         model.use_kernel = self.use_kernel
         model.device = self.device
-        model = model.to(self.device)
+        # model = model.to(self.device)
         return model
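A small sketch of the dtype guard added to __init__ above, assuming only float16, float32, and int8 are accepted for DeepSpeed inference here (which matches the narrowed dtype field in models.py later in this commit). The helper name is illustrative.

    # Sketch of the dtype guard; check_deepspeed_dtype is not a repository function.
    import torch

    SUPPORTED_DEEPSPEED_DTYPES = (torch.float16, torch.float32, torch.int8)

    def check_deepspeed_dtype(dtype: torch.dtype) -> torch.dtype:
        if dtype not in SUPPORTED_DEEPSPEED_DTYPES:
            raise ValueError(f"dtype {dtype} is not supported by DeepSpeed.")
        return dtype

    check_deepspeed_dtype(torch.float16)     # passes
    # check_deepspeed_dtype(torch.bfloat16)  # would raise ValueError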
16 changes: 13 additions & 3 deletions aviary/backend/llm/predictor.py
@@ -36,6 +36,7 @@ def init_model(
     world_size: int,
     local_rank: int,
     max_batch_size: Optional[int] = None,
+    full_warmup: bool = False,
 ):
     """Initialize the model.
@@ -44,6 +45,8 @@
         world_size (int): Number of GPUs.
         local_rank (int): Local rank of the current GPU.
         max_batch_size (Optional[int], optional): Maximum batch size. Defaults to None.
+        full_warmup (bool, optional): Whether to do a full warmup (min_new_tokens=max_new_tokens and
+            max out input length). Defaults to False.
     """
     logger.info(f"Initializing model {llm_config.model_id}...")
 
@@ -75,16 +78,22 @@
     # otherwise subsequent batches with more entries than the first batch
     # will raise CUDA errors if use_kernel=True.
     batch_size = max_batch_size or 1
-    prompt = [WARMUP_PROMPT] * (
-        int(llm_config.max_input_words / (len(WARMUP_PROMPT.split()) + 1)) + 1
+    n_repeats = llm_config.max_input_words if full_warmup else 1
+    prompt = [WARMUP_PROMPT] * max(
+        1, (int(n_repeats / (len(WARMUP_PROMPT.split()) + 1)))
     )
     prompt = " ".join(prompt)
     logger.info(
         f"Model {llm_config.model_id} is warming up, input len {len(prompt)}..."
     )
     generate_kwargs = llm_config.generation.all_generate_kwargs.copy()
     if "max_new_tokens" in generate_kwargs:
-        generate_kwargs["min_new_tokens"] = generate_kwargs["max_new_tokens"]
+        if full_warmup:
+            generate_kwargs["min_new_tokens"] = generate_kwargs["max_new_tokens"]
+        else:
+            generate_kwargs["max_new_tokens"] = generate_kwargs.get(
+                "min_new_tokens", 16
+            )
     warmup_success = False
     while not warmup_success:
         try:
@@ -178,6 +187,7 @@ def init_model(
             self.world_size,
             local_rank,
             max_batch_size=self.llm_config.generation.max_batch_size,
+            full_warmup=self.llm_config.initialization.full_warmup,
         )
 
     def generate(
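For reference, a hedged sketch of the warmup sizing logic above. WARMUP_PROMPT and the flat dict of generate kwargs are stand-ins for the repository's config objects; only the control flow mirrors the hunk.

    # Full warmup maxes out input length and forces the longest generation;
    # the default keeps warmup cheap.
    WARMUP_PROMPT = "Write a short story."

    def build_warmup_request(max_input_words: int, generate_kwargs: dict, full_warmup: bool):
        n_repeats = max_input_words if full_warmup else 1
        n_copies = max(1, int(n_repeats / (len(WARMUP_PROMPT.split()) + 1)))
        prompt = " ".join([WARMUP_PROMPT] * n_copies)

        kwargs = dict(generate_kwargs)
        if "max_new_tokens" in kwargs:
            if full_warmup:
                # Generate the maximum number of tokens so kernels see the worst case.
                kwargs["min_new_tokens"] = kwargs["max_new_tokens"]
            else:
                # Cap generation at the configured minimum (or 16) to keep warmup fast.
                kwargs["max_new_tokens"] = kwargs.get("min_new_tokens", 16)
        return prompt, kwargs

    # Example: the cheap default warmup sends one prompt copy and generates at most 16 tokens.
    prompt, kwargs = build_warmup_request(800, {"max_new_tokens": 512, "min_new_tokens": 16}, False)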
4 changes: 3 additions & 1 deletion aviary/backend/llm/utils.py
@@ -38,7 +38,6 @@ def download_model(
     from transformers.utils.hub import TRANSFORMERS_CACHE
 
     s3_sync_args = s3_sync_args or []
-    logger.info(f"Downloading {model_id} from {bucket_uri} to '{TRANSFORMERS_CACHE}'")
     path = os.path.expanduser(
         os.path.join(TRANSFORMERS_CACHE, f"models--{model_id.replace('/', '--')}")
     )
@@ -53,6 +52,9 @@
     )
     with open(os.path.join(".", "hash"), "r") as f:
         f_hash = f.read().strip()
+    logger.info(
+        f"Downloading {model_id} from {bucket_uri} to {os.path.join(path, 'snapshots', f_hash)}"
+    )
     subprocess.run(["mkdir", "-p", os.path.join(path, "snapshots", f_hash)])
     subprocess.run(["mkdir", "-p", os.path.join(path, "refs")])
     subprocess.run(
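A small sketch of the snapshot path the new log line reports, following the standard Hugging Face cache layout (models--{org}--{name}/snapshots/{hash}). The function name and the hash value in the example are illustrative only.

    import os
    from transformers.utils.hub import TRANSFORMERS_CACHE

    def snapshot_dir(model_id: str, f_hash: str) -> str:
        repo_dir = f"models--{model_id.replace('/', '--')}"
        return os.path.expanduser(os.path.join(TRANSFORMERS_CACHE, repo_dir, "snapshots", f_hash))

    # Example (hash is a placeholder, not a real snapshot hash):
    print(snapshot_dir("OpenAssistant/falcon-7b-sft-top1-696", "0123abcd"))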
6 changes: 5 additions & 1 deletion aviary/backend/server/models.py
@@ -208,7 +208,9 @@ def allowed_pipelines(self) -> Set[str]:
 class Transformers(Initializer, extra=Extra.forbid):
     use_bettertransformer: bool = False
     torch_compile: Optional[TorchCompile] = None
-    dtype: str = "float16"
+    dtype: Union[
+        Literal["float16"], Literal["bfloat16"], Literal["float32"], Literal["int8"]
+    ] = "float16"
     from_pretrained_kwargs: Dict[str, Any] = {}
 
     @property
@@ -229,6 +231,7 @@ def allowed_pipelines(self) -> Set[str]:
 
 class DeepSpeed(Transformers):
     type: Literal["DeepSpeed"]
+    dtype: Union[Literal["float16"], Literal["float32"], Literal["int8"]] = "float16"
     use_kernel: bool = False
     max_tokens: int = 1024
     use_meta_tensor: bool = False
@@ -289,6 +292,7 @@ class InitializationConfig(BaseModelExtended):
     s3_mirror_config: Optional[S3MirrorConfig] = None
     runtime_env: Optional[Dict[str, Any]] = None
     hf_model_id: Optional[str] = None
+    full_warmup: bool = False  # For debugging purposes
 
     @root_validator
     def initializer_pipeline(cls, values):
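A sketch of the narrowed dtype fields above using the same Union[Literal[...]] pattern, so an unsupported string is rejected at config load time. The class names here are simplified stand-ins, not the repository's models.

    from typing import Literal, Union

    from pydantic import BaseModel

    class TransformersInitializerConfig(BaseModel):
        dtype: Union[
            Literal["float16"], Literal["bfloat16"], Literal["float32"], Literal["int8"]
        ] = "float16"

    class DeepSpeedInitializerConfig(TransformersInitializerConfig):
        # DeepSpeed drops bfloat16, matching the runtime check added in deepspeed.py.
        dtype: Union[Literal["float16"], Literal["float32"], Literal["int8"]] = "float16"

    DeepSpeedInitializerConfig(dtype="float16")     # validates
    # DeepSpeedInitializerConfig(dtype="bfloat16")  # raises pydantic.ValidationError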
2 changes: 1 addition & 1 deletion deploy/_internal/backend/cluster-env.yaml
@@ -9,7 +9,7 @@ env_vars:
   TORCH_HOME: /mnt/local_storage/data/cache/torch
 post_build_cmds:
 - |-
-  echo "dedup version 1. increment this to force a rebuild."
+  echo "dedup version 2. increment this to force a rebuild."
   pip uninstall -y torch-scatter torch-sparse torch-cluster torch-spline-conv torch-geometric tensorflow
   pip install -i https://download.pytorch.org/whl/cu118 torch torchvision torchaudio
   pip install \
46 changes: 46 additions & 0 deletions models/OpenAssistant--falcon-40b-sft-top1-560.yaml
@@ -0,0 +1,46 @@
deployment_config:
  autoscaling_config:
    min_replicas: 2
    initial_replicas: 2
    max_replicas: 2
    target_num_ongoing_requests_per_replica: 1.0
    metrics_interval_s: 10.0
    look_back_period_s: 30.0
    smoothing_factor: 1.0
    downscale_delay_s: 300.0
    upscale_delay_s: 90.0
  ray_actor_options:
    resources:
      accelerator_type_cpu: 0.01
model_config:
  model_id: OpenAssistant/falcon-40b-sft-top1-560
  max_input_words: 800
  initialization:
    s3_mirror_config:
      bucket_uri: s3://large-dl-models-mirror/models--OpenAssistant--falcon-40b-sft-top1-560/main-safetensors/
    initializer:
      type: DeepSpeed
      dtype: float16
      from_pretrained_kwargs:
        trust_remote_code: true
      use_kernel: true
      max_tokens: 1536
    pipeline: default
  generation:
    max_batch_size: 8
    generate_kwargs:
      do_sample: true
      max_new_tokens: 512
      min_new_tokens: 16
      temperature: 0.4
      top_p: 0.9
      repetition_penalty: 1.02
      return_token_type_ids: false
    prompt_format: "<|prefix_begin|>Below are a series of dialogues between various people and an AI assistant. The AI tries to be helpful, polite, honest, sophisticated, emotionally aware, and humble-but-knowledgeable. The assistant is happy to help with almost anything, and will do its best to understand exactly what is needed. It also tries to avoid giving false or misleading information, and it caveats when it isn't entirely sure about the right answer. That said, the assistant is practical and really does its best, and doesn't let caution get too much in the way of being useful.<|prefix_end|><|prompter|>{instruction}<|endoftext|><|assistant|>"
    stopping_sequences: ["<|prompter|>", "<|assistant|>", "<|endoftext|>"]
scaling_config:
  num_workers: 2
  num_gpus_per_worker: 1
  num_cpus_per_worker: 8
  resources_per_worker:
    accelerator_type_a100: 0.01
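As an illustration of how a prompt_format template like the one above is used, the serving layer substitutes the user instruction into {instruction} before generation. The preamble is elided below and the function name is a stand-in, not repository API.

    PROMPT_FORMAT = (
        "<|prefix_begin|>...system preamble elided...<|prefix_end|>"
        "<|prompter|>{instruction}<|endoftext|><|assistant|>"
    )

    def render_prompt(instruction: str) -> str:
        # Fill the template with the user's instruction.
        return PROMPT_FORMAT.format(instruction=instruction)

    print(render_prompt("Summarize what the Falcon models are."))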
47 changes: 47 additions & 0 deletions models/OpenAssistant--falcon-7b-sft-top1-696.yaml
@@ -0,0 +1,47 @@
deployment_config:
  autoscaling_config:
    min_replicas: 1
    initial_replicas: 1
    max_replicas: 8
    target_num_ongoing_requests_per_replica: 1.0
    metrics_interval_s: 10.0
    look_back_period_s: 30.0
    smoothing_factor: 1.0
    downscale_delay_s: 300.0
    upscale_delay_s: 90.0
  ray_actor_options:
    resources:
      accelerator_type_cpu: 0.01
model_config:
  model_id: OpenAssistant/falcon-7b-sft-top1-696
  max_input_words: 800
  initialization:
    s3_mirror_config:
      bucket_uri: s3://large-dl-models-mirror/models--OpenAssistant--falcon-7b-sft-top1-696/main-safetensors/
    initializer:
      type: DeviceMap
      dtype: bfloat16
      from_pretrained_kwargs:
        trust_remote_code: true
      torch_compile:
        backend: inductor
        mode: max-autotune
    pipeline: default
  generation:
    max_batch_size: 4
    generate_kwargs:
      do_sample: true
      max_new_tokens: 512
      min_new_tokens: 16
      temperature: 0.4
      top_p: 0.9
      repetition_penalty: 1.02
      return_token_type_ids: false
    prompt_format: "<|prefix_begin|>Below are a series of dialogues between various people and an AI assistant. The AI tries to be helpful, polite, honest, sophisticated, emotionally aware, and humble-but-knowledgeable. The assistant is happy to help with almost anything, and will do its best to understand exactly what is needed. It also tries to avoid giving false or misleading information, and it caveats when it isn't entirely sure about the right answer. That said, the assistant is practical and really does its best, and doesn't let caution get too much in the way of being useful.<|prefix_end|><|prompter|>{instruction}<|endoftext|><|assistant|>"
    stopping_sequences: ["<|prompter|>", "<|assistant|>", "<|endoftext|>"]
scaling_config:
  num_workers: 1
  num_gpus_per_worker: 1
  num_cpus_per_worker: 8
  resources_per_worker:
    accelerator_type_a10: 0.01
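A hedged sketch of what the torch_compile block above plausibly maps to at load time: the backend and mode strings come straight from the YAML, and torch.compile (PyTorch 2.x) wraps the loaded model. The helper name is illustrative.

    import torch

    def maybe_compile(model: torch.nn.Module, backend: str = "inductor", mode: str = "max-autotune"):
        # torch.compile returns an optimized callable wrapper; the original module is unchanged.
        return torch.compile(model, backend=backend, mode=mode)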
12 changes: 6 additions & 6 deletions models/OpenAssistant--oasst-sft-7-llama-30b-xor.yaml
@@ -1,8 +1,8 @@
 deployment_config:
   autoscaling_config:
-    min_replicas: 1
-    initial_replicas: 1
-    max_replicas: 4
+    min_replicas: 2
+    initial_replicas: 2
+    max_replicas: 2
     target_num_ongoing_requests_per_replica: 1.0
     metrics_interval_s: 10.0
     look_back_period_s: 30.0
@@ -32,7 +32,7 @@ model_config:
       max_tokens: 1536
     pipeline: default
   generation:
-    max_batch_size: 4
+    max_batch_size: 16
     generate_kwargs:
       do_sample: true
       max_new_tokens: 512
@@ -46,8 +46,8 @@ model_config:
     prompt_format: "<|prefix_begin|>Below are a series of dialogues between various people and an AI assistant. The AI tries to be helpful, polite, honest, sophisticated, emotionally aware, and humble-but-knowledgeable. The assistant is happy to help with almost anything, and will do its best to understand exactly what is needed. It also tries to avoid giving false or misleading information, and it caveats when it isn't entirely sure about the right answer. That said, the assistant is practical and really does its best, and doesn't let caution get too much in the way of being useful.<|prefix_end|><|prompter|>{instruction}</s><|assistant|>"
     stopping_sequences: [2, 32002, 32004]
 scaling_config:
-  num_workers: 4
+  num_workers: 2
   num_gpus_per_worker: 1
   num_cpus_per_worker: 4
   resources_per_worker:
-    accelerator_type_a10: 0.01
+    accelerator_type_a100: 0.01
4 changes: 2 additions & 2 deletions models/mosaicml--mpt-7b-chat.yaml
@@ -1,7 +1,7 @@
 deployment_config:
   autoscaling_config:
-    min_replicas: 2
-    initial_replicas: 2
+    min_replicas: 1
+    initial_replicas: 1
     max_replicas: 8
     target_num_ongoing_requests_per_replica: 1.0
     metrics_interval_s: 10.0
