models.py (forked from meta-llama/llama-recipes)
# Copyright (c) Meta Platforms, Inc. and affiliates.
# This software may be used and distributed according to the terms of the Llama 2 Community License Agreement.

import os
from typing import List, Optional, TYPE_CHECKING, Tuple

import torch
from transformers import (
    LlamaForCausalLM,
    LlamaConfig,
)

from olive.models.config import ModelConfig

if TYPE_CHECKING:
    from .training import TrainConfig


class LLaMaConfig(ModelConfig):
    model_name: str

    def initialize(self, train_config: "TrainConfig"):
        if train_config.enable_fsdp:
            local_rank = int(os.environ["LOCAL_RANK"])
            rank = int(os.environ["RANK"])
            world_size = int(os.environ["WORLD_SIZE"])

        # Load the pre-trained model and set up its configuration
        use_cache = False if train_config.enable_fsdp else None
        if train_config.enable_fsdp and train_config.low_cpu_fsdp:
            """
            For FSDP, we can save CPU memory by loading the pretrained model on rank 0 only.
            This avoids CPU OOM when loading large models like Llama 70B, in which case the
            model alone would consume 2+ TB of CPU memory (70 * 4 * 8). This adds some comms
            overhead and currently requires the latest nightly.
            """
            if rank == 0:
                model = LlamaForCausalLM.from_pretrained(
                    self.model_name,
                    load_in_8bit=True if train_config.quantization else None,
                    device_map="auto" if train_config.quantization else None,
                    use_cache=use_cache,
                    attn_implementation="sdpa" if train_config.use_fast_kernels else None,
                )
            else:
                llama_config = LlamaConfig.from_pretrained(self.model_name)
                llama_config.use_cache = use_cache
                with torch.device("meta"):
                    model = LlamaForCausalLM(llama_config)
        else:
            model = LlamaForCausalLM.from_pretrained(
                self.model_name,
                load_in_8bit=True if train_config.quantization else None,
                device_map="auto" if train_config.quantization else None,
                use_cache=use_cache,
                attn_implementation="sdpa" if train_config.use_fast_kernels else None,
            )
        return model
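
# Usage sketch (illustration only, not part of the original file). Under a
# torchrun launch, LOCAL_RANK / RANK / WORLD_SIZE are set by the launcher,
# which is what initialize() reads when train_config.enable_fsdp is True. The
# TrainConfig fields referenced here (enable_fsdp, low_cpu_fsdp, quantization,
# use_fast_kernels) are exactly the attributes accessed above; how TrainConfig
# itself is constructed is an assumption.
#
#   config = LLaMaConfig(model_name="meta-llama/Llama-2-7b-hf")
#   model = config.initialize(train_config)
#
# With enable_fsdp + low_cpu_fsdp, only rank 0 holds real weights; the other
# ranks build a meta-device shell and are expected to receive parameters from
# rank 0 when the model is later wrapped with FSDP (typically via
# sync_module_states=True).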


class TiedLLaMaConfig(ModelConfig):
    model_name: str
    # Format: a list of tuples [(src, dst), ...], where the dst layer will be
    # replaced with the src layer.
    tied_layers: Optional[List[Tuple[int, int]]] = None

    def initialize(self, train_config: "TrainConfig"):
        if train_config.enable_fsdp:
            local_rank = int(os.environ["LOCAL_RANK"])
            rank = int(os.environ["RANK"])
            world_size = int(os.environ["WORLD_SIZE"])

        # Load the pre-trained model and set up its configuration
        use_cache = False if train_config.enable_fsdp else None
        if train_config.enable_fsdp and train_config.low_cpu_fsdp:
            """
            For FSDP, we can save CPU memory by loading the pretrained model on rank 0 only.
            This avoids CPU OOM when loading large models like Llama 70B, in which case the
            model alone would consume 2+ TB of CPU memory (70 * 4 * 8). This adds some comms
            overhead and currently requires the latest nightly.
            """
            if rank == 0:
                model = LlamaForCausalLM.from_pretrained(
                    self.model_name,
                    load_in_8bit=True if train_config.quantization else None,
                    device_map="auto" if train_config.quantization else None,
                    use_cache=use_cache,
                    attn_implementation="sdpa" if train_config.use_fast_kernels else None,
                )
            else:
                llama_config = LlamaConfig.from_pretrained(self.model_name)
                llama_config.use_cache = use_cache
                with torch.device("meta"):
                    model = LlamaForCausalLM(llama_config)
        else:
            model = LlamaForCausalLM.from_pretrained(
                self.model_name,
                load_in_8bit=True if train_config.quantization else None,
                device_map="auto" if train_config.quantization else None,
                use_cache=use_cache,
                attn_implementation="sdpa" if train_config.use_fast_kernels else None,
            )

        if self.tied_layers is not None:
            for src, dst in self.tied_layers:
                print(f"Replacing layer {dst} with layer {src}")
                model.model.layers[dst] = model.model.layers[src]
        return model
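
# Usage sketch (illustration only, not part of the original file). Tie decoder
# layers so that the destination layer reuses the module object of the source
# layer; the keyword-argument constructor assumes ModelConfig behaves like a
# pydantic-style config with the fields declared above, which is an assumption.
#
#   tied = TiedLLaMaConfig(
#       model_name="meta-llama/Llama-2-7b-hf",
#       tied_layers=[(0, 31)],
#   )
#   model = tied.initialize(train_config)
#   assert model.model.layers[31] is model.model.layers[0]
#
# Because the swap assigns the same nn.Module object to both indices, the tied
# layers share parameters (and gradients) rather than holding copies.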