This repository has been archived by the owner on May 28, 2024. It is now read-only.

Fix loading without S3 (#112)
Fixes #9

Signed-off-by: Antoni Baum <[email protected]>
Yard1 authored Jun 6, 2023
1 parent bfdd10b commit 0ececcc
Showing 5 changed files with 92 additions and 24 deletions.
44 changes: 37 additions & 7 deletions aviary/backend/llm/initializers/hf_transformers/base.py
@@ -74,7 +74,11 @@ def _get_model_location_on_disk(self, model_id: str) -> str:
         if os.path.exists(path):
             with open(os.path.join(path, "refs", "main"), "r") as f:
                 snapshot_hash = f.read().strip()
-            if os.path.exists(os.path.join(path, "snapshots", snapshot_hash)):
+            if os.path.exists(
+                os.path.join(path, "snapshots", snapshot_hash)
+            ) and os.path.exists(
+                os.path.join(path, "snapshots", snapshot_hash, "config.json")
+            ):
                 model_id_or_path = os.path.join(path, "snapshots", snapshot_hash)
         return model_id_or_path
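
The extra config.json check guards against partially populated caches: a snapshot directory can exist while holding only tokenizer files, in which case loading a model from it would fail. A minimal sketch of the same resolution logic, assuming the standard Hugging Face Hub cache layout (models--org--name/refs/main pointing at snapshots/<hash>); resolve_local_snapshot is a hypothetical name:

import os

def resolve_local_snapshot(model_id: str, cache_dir: str) -> str:
    """Return a cached snapshot path only if it holds a full model copy."""
    path = os.path.join(cache_dir, f"models--{model_id.replace('/', '--')}")
    ref = os.path.join(path, "refs", "main")
    if os.path.exists(ref):
        with open(ref, "r") as f:
            snapshot_hash = f.read().strip()
        snapshot = os.path.join(path, "snapshots", snapshot_hash)
        # A snapshot dir may exist but contain only tokenizer files;
        # config.json is the cheap marker that a model config is present.
        if os.path.exists(os.path.join(snapshot, "config.json")):
            return snapshot
    # Fall back to the hub ID so from_pretrained() can download the files.
    return model_id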

@@ -85,11 +89,24 @@ def load_model(self, model_id: str) -> "PreTrainedModel":
             model_id (str): Hugging Face model ID.
         """
         model_id_or_path = self._get_model_location_on_disk(model_id)
+        from_pretrained_kwargs = self._get_model_from_pretrained_kwargs()

         logger.info(f"Loading model {model_id_or_path}...")
-        model = AutoModelForCausalLM.from_pretrained(
-            model_id_or_path, **self._get_model_from_pretrained_kwargs()
-        )
+        try:
+            model = AutoModelForCausalLM.from_pretrained(
+                model_id_or_path, **from_pretrained_kwargs
+            )
+        except OSError:
+            if model_id_or_path != model_id:
+                logger.warning(
+                    f"Couldn't load model from derived path {model_id_or_path}, "
+                    f"trying to load from model_id {model_id}"
+                )
+                model = AutoModelForCausalLM.from_pretrained(
+                    model_id, **from_pretrained_kwargs
+                )
+            else:
+                raise
         model.eval()
         return model
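
The hunk above introduces a local-first, hub-fallback loading pattern: try the path derived from the cache, and on OSError (the error from_pretrained raises for missing or incomplete files) retry with the original model ID. A self-contained sketch of the same pattern, with load_with_fallback as a hypothetical name:

import logging

from transformers import AutoModelForCausalLM

logger = logging.getLogger(__name__)

def load_with_fallback(model_id: str, model_id_or_path: str, **kwargs):
    """Try a locally derived path first; retry with the hub ID on OSError."""
    try:
        return AutoModelForCausalLM.from_pretrained(model_id_or_path, **kwargs)
    except OSError:
        if model_id_or_path == model_id:
            raise  # nothing different to retry with
        logger.warning(
            "Couldn't load from %s, retrying with %s", model_id_or_path, model_id
        )
        return AutoModelForCausalLM.from_pretrained(model_id, **kwargs)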

@@ -100,15 +117,28 @@ def load_tokenizer(self, tokenizer_id: str) -> "PreTrainedTokenizer":
             tokenizer_id (str): Hugging Face tokenizer name.
         """
         tokenizer_id_or_path = self._get_model_location_on_disk(tokenizer_id)
+        from_pretrained_kwargs = self._get_model_from_pretrained_kwargs()

-        # TODO make this more robust, add logging
+        # TODO make this more robust
        try:
             return AutoTokenizer.from_pretrained(
-                tokenizer_id_or_path, padding_side="left", trust_remote_code=True
+                tokenizer_id_or_path,
+                padding_side="left",
+                trust_remote_code=from_pretrained_kwargs.get(
+                    "trust_remote_code", False
+                ),
             )
         except Exception:
+            logger.warning(
+                f"Couldn't load tokenizer from derived path {tokenizer_id_or_path}, "
+                f"trying to load from model_id {tokenizer_id}"
+            )
             return AutoTokenizer.from_pretrained(
-                tokenizer_id, padding_side="left", trust_remote_code=True
+                tokenizer_id,
+                padding_side="left",
+                trust_remote_code=from_pretrained_kwargs.get(
+                    "trust_remote_code", False
+                ),
             )

     def postprocess_model(self, model: "PreTrainedModel") -> "PreTrainedModel":
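
The tokenizer path gets the same fallback, and additionally stops hardcoding trust_remote_code=True: the flag is now read from the shared from_pretrained kwargs, so custom code shipped with a Hub repo only executes when the caller opted in. A short sketch of that change in isolation (load_tokenizer_safely is a hypothetical name):

from transformers import AutoTokenizer

def load_tokenizer_safely(tokenizer_id: str, from_pretrained_kwargs: dict):
    # Thread the caller's trust_remote_code setting through instead of
    # defaulting to True.
    return AutoTokenizer.from_pretrained(
        tokenizer_id,
        padding_side="left",
        trust_remote_code=from_pretrained_kwargs.get("trust_remote_code", False),
    )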
67 changes: 52 additions & 15 deletions aviary/backend/llm/initializers/hf_transformers/deepspeed.py
@@ -30,6 +30,8 @@ class DeepSpeedInitializer(TransformersInitializer):
         use_meta_tensor (bool, optional): Whether to use meta tensor loading method. Defaults to False.
         injection_policy ([type], optional): Injection policy for DeepSpeed AutoTP. Cannot
             be set if use_kernel=True. Defaults to None.
+        ds_inference_kwargs (Dict[str, Any], optional): Other keyword arguments for ``deepspeed.init_inference``.
+            Specific arguments in the signature of this function will override these values.
         **from_pretrained_kwargs: Keyword arguments for ``AutoModel.from_pretrained``.
     """

@@ -44,6 +46,7 @@ def __init__(
         use_kernel: bool = False,
         use_meta_tensor: bool = False,
         injection_policy=None,
+        ds_inference_kwargs: Optional[Dict[str, Any]] = None,
         **from_pretrained_kwargs,
     ):
         super().__init__(
@@ -59,6 +62,7 @@ def __init__(
         self.use_meta_tensor = use_meta_tensor
         # TODO: Allow conversion from strings (need to do dynamic imports)
         self.injection_policy = injection_policy
+        self.ds_inference_kwargs = ds_inference_kwargs

         if self.use_kernel:
             assert not (self.use_bettertransformer or self.torch_compile)
@@ -115,30 +119,58 @@ def _generate_checkpoint_json(

     def load_model(self, model_id: str) -> "PreTrainedModel":
         model_id_or_path = self._get_model_location_on_disk(model_id)
+        from_pretrained_kwargs = self._get_model_from_pretrained_kwargs()

         logger.info(f"Loading model {model_id_or_path}...")
         if self.use_meta_tensor:
             logger.info("Loading model using DeepSpeed meta tensor...")
-            config = AutoConfig.from_pretrained(
-                model_id_or_path, **self._get_model_from_pretrained_kwargs()
-            )
+
+            try:
+                config = AutoConfig.from_pretrained(
+                    model_id_or_path, **from_pretrained_kwargs
+                )
+            except OSError:
+                if model_id_or_path != model_id:
+                    logger.warning(
+                        f"Couldn't load model from derived path {model_id_or_path}, "
+                        f"trying to load from model_id {model_id}"
+                    )
+                    config = AutoConfig.from_pretrained(
+                        model_id, **from_pretrained_kwargs
+                    )
+                else:
+                    raise
+
             self._repo_root, self._checkpoints_json = self._generate_checkpoint_json(
                 model_id
             )
+
             with deepspeed.OnDevice(dtype=torch.float16, device="meta"):
                 model = AutoModelForCausalLM.from_config(config)
         else:
-            model = AutoModelForCausalLM.from_pretrained(
-                model_id_or_path, **self._get_model_from_pretrained_kwargs()
-            )
+            try:
+                model = AutoModelForCausalLM.from_pretrained(
+                    model_id_or_path, **from_pretrained_kwargs
+                )
+            except OSError:
+                if model_id_or_path != model_id:
+                    logger.warning(
+                        f"Couldn't load model from derived path {model_id_or_path}, "
+                        f"trying to load from model_id {model_id}"
+                    )
+                    model = AutoModelForCausalLM.from_pretrained(
+                        model_id, **from_pretrained_kwargs
+                    )
+                else:
+                    raise
         model.eval()
         return model

     def postprocess_model(self, model: "PreTrainedModel") -> "PreTrainedModel":
         from transformers import GPTNeoXForCausalLM, LlamaForCausalLM

         injection_policy = self.injection_policy
         # TODO: remove those later when deepspeed master is updated
         if injection_policy is None and not self.use_kernel:
             if isinstance(model, GPTNeoXForCausalLM):
                 from transformers import GPTNeoXLayer
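
For reference, the meta-tensor branch above defers weight allocation: the model skeleton is created on the "meta" device from the config alone, and DeepSpeed streams the real weights from the checkpoint JSON during init_inference. A condensed sketch of that flow, with illustrative IDs and paths:

import deepspeed
import torch
from transformers import AutoConfig, AutoModelForCausalLM

config = AutoConfig.from_pretrained("facebook/opt-1.3b")  # illustrative

# No weights are allocated here: every parameter lives on the meta device.
with deepspeed.OnDevice(dtype=torch.float16, device="meta"):
    model = AutoModelForCausalLM.from_config(config)

# DeepSpeed materializes the real fp16 weights from the checkpoint files
# (the checkpoints.json produced by _generate_checkpoint_json) while it
# injects its inference kernels.
model = deepspeed.init_inference(
    model,
    dtype=torch.float16,
    replace_with_kernel_inject=True,
    base_dir="/path/to/model/snapshot",       # illustrative
    checkpoint="/path/to/checkpoints.json",   # illustrative
)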
@@ -159,20 +191,25 @@ def postprocess_model(self, model: "PreTrainedModel") -> "PreTrainedModel":
             logger.info("Transforming the model with BetterTransformer...")
             model = BetterTransformer.transform(model)

+        ds_kwargs = self.ds_inference_kwargs or {}
+        ds_kwargs = ds_kwargs.copy()
+        ds_kwargs.update(
+            dict(
+                dtype=self.dtype,
+                mp_size=self.world_size,
+                replace_with_kernel_inject=self.use_kernel,
+                injection_policy=injection_policy,
+                max_tokens=self.max_tokens,
+            )
+        )
         if self.use_meta_tensor:
-            ds_kwargs = dict(
-                base_dir=self._repo_root, checkpoint=self._checkpoints_json
+            ds_kwargs.update(
+                dict(base_dir=self._repo_root, checkpoint=self._checkpoints_json)
             )
-        else:
-            ds_kwargs = dict()

+        logger.info(f"deepspeed.init_inference kwargs: {ds_kwargs}")
         model = deepspeed.init_inference(
             model,
-            dtype=self.dtype,
-            mp_size=self.world_size,
-            replace_with_kernel_inject=self.use_kernel,
-            injection_policy=injection_policy,
-            max_tokens=self.max_tokens,
             **ds_kwargs,
         )
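
The merge order is what implements the docstring's promise that explicit arguments override ds_inference_kwargs: the user-supplied dict forms the base, and the initializer's own values are update()d on top. A tiny, runnable illustration of that precedence (values are placeholders):

# User-supplied kwargs form the base dict.
ds_inference_kwargs = {"dtype": "user-dtype", "enable_cuda_graph": False}

ds_kwargs = (ds_inference_kwargs or {}).copy()
# The initializer's explicit arguments are applied last, so they win.
ds_kwargs.update(dict(dtype="initializer-dtype", mp_size=2))

assert ds_kwargs["dtype"] == "initializer-dtype"  # explicit argument overrode
assert ds_kwargs["enable_cuda_graph"] is False    # user-only key survives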

2 changes: 1 addition & 1 deletion aviary/backend/llm/pipelines/default_pipeline.py
@@ -89,7 +89,7 @@ def postprocess(self, model_outputs, **postprocess_kwargs) -> List[Response]:
         token_stopper = next(
             (
                 x
-                for x in model_outputs["generate_kwargs"]["stopping_criteria"]
+                for x in model_outputs["generate_kwargs"].get("stopping_criteria", [])
                 if isinstance(x, StopOnTokens)
             ),
             None,
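
Using .get("stopping_criteria", []) keeps the generator expression valid when no stopping criteria were ever registered, and the next(..., None) default then yields None instead of an exception. A minimal reproduction (StopOnTokens stubbed out):

class StopOnTokens:
    """Stub standing in for the pipeline's stopping-criterion class."""

generate_kwargs = {}  # "stopping_criteria" was never set

token_stopper = next(
    (
        x
        for x in generate_kwargs.get("stopping_criteria", [])
        if isinstance(x, StopOnTokens)
    ),
    None,
)
assert token_stopper is None  # the old direct indexing raised KeyError here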
2 changes: 1 addition & 1 deletion aviary/backend/llm/utils.py
@@ -103,7 +103,7 @@ def initialize_node(
     torch_cache_home = _get_torch_home()
     os.makedirs(os.path.join(torch_cache_home, "kernels"), exist_ok=True)

-    if model_id and s3_mirror_config:
+    if model_id and s3_mirror_config and s3_mirror_config.bucket_uri:
         lock_path = os.path.expanduser(f"~/{model_id.replace('/', '--')}.lock")
         try:
             # Timeout 0 means there will be only one attempt to acquire
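
The bucket_uri check makes a configured-but-empty S3 mirror behave like no mirror at all, so nodes without S3 skip the download path entirely. The truncated code below the hunk acquires a file lock before downloading; a sketch of that pattern with the filelock package, under the same timeout=0, single-attempt semantics the comment describes (the download body is an assumption):

import os

from filelock import FileLock, Timeout

model_id = "org/model"  # illustrative
lock_path = os.path.expanduser(f"~/{model_id.replace('/', '--')}.lock")

try:
    # timeout=0: one acquisition attempt; whichever worker gets the lock
    # performs the download, the rest move on instead of blocking.
    with FileLock(lock_path, timeout=0):
        pass  # assumed: download model files from the S3 mirror here
except Timeout:
    pass  # another worker on this node is already downloading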
1 change: 1 addition & 0 deletions aviary/backend/server/models.py
@@ -232,6 +232,7 @@ class DeepSpeed(Transformers):
     use_kernel: bool = False
     max_tokens: int = 1024
     use_meta_tensor: bool = False
+    ds_inference_kwargs: Optional[Dict[str, Any]] = None

     @root_validator
     def use_kernel_bettertransformer_torch_compile(cls, values):
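
The new field mirrors the initializer's constructor argument on the pydantic config model. The validator body is truncated in this view; a sketch of a v1-style root_validator enforcing the same incompatibility the initializer asserts (the inherited field names are assumptions):

from typing import Any, Dict, Optional

from pydantic import BaseModel, root_validator

class DeepSpeedSketch(BaseModel):
    """Sketch only; the real class extends Transformers in aviary's models.py."""

    use_bettertransformer: bool = False  # assumed inherited field
    torch_compile: bool = False          # assumed inherited field
    use_kernel: bool = False
    max_tokens: int = 1024
    use_meta_tensor: bool = False
    ds_inference_kwargs: Optional[Dict[str, Any]] = None

    @root_validator
    def use_kernel_bettertransformer_torch_compile(cls, values):
        # Assumed check, mirroring the assert in DeepSpeedInitializer.__init__:
        # the DeepSpeed kernel can't be combined with BetterTransformer or
        # torch.compile.
        if values.get("use_kernel"):
            assert not (
                values.get("use_bettertransformer") or values.get("torch_compile")
            ), "use_kernel is incompatible with BetterTransformer/torch.compile"
        return values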
Expand Down
