export lmdeploy version
irexyc committed Nov 14, 2023
1 parent f24c905 commit 371320a
Showing 5 changed files with 13 additions and 14 deletions.
2 changes: 2 additions & 0 deletions lmdeploy/lite/utils/export_turbomind.py
@@ -50,5 +50,7 @@ def export_turbomind_hf_model(model_name: str,
     with open(config_file) as f:
         data = json.load(f)
     data['turbomind'] = config
+    from lmdeploy.version import __version__
+    data['lmdeploy_version'] = __version__
     with open(config_file, 'w') as f:
         f.write(json.dumps(data, indent=2) + '\n')
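
Note: the two added lines stamp the exported JSON config with the lmdeploy version that produced it. A minimal sketch of reading that field back, assuming the same JSON file written by export_turbomind_hf_model (the path below is a placeholder, not part of the commit):

import json
from lmdeploy.version import __version__

config_file = 'workspace/config.json'  # placeholder path, for illustration only
with open(config_file) as f:
    data = json.load(f)

# 'lmdeploy_version' now sits next to the 'turbomind' section written above
print('exported by lmdeploy', data.get('lmdeploy_version'))
print('matches installed version:', data.get('lmdeploy_version') == __version__)
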
1 change: 1 addition & 0 deletions lmdeploy/turbomind/deploy/source_model/base.py
@@ -64,6 +64,7 @@ def clean_up(self, last: bool) -> None:
         for key in self.params:
             layer_id = re.findall(self.attn_layer_patten, key)
             if len(layer_id) == 0:
+                # tok, norm, output
                 to_remove.append(key)
             else:
                 layer_id = int(layer_id[0])
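
Note: the added comment documents which keys miss the attention-layer regex: weights with no layer index (token embedding, final norm, output head), which clean_up collects in to_remove. A rough illustration of that filtering; the pattern and key names below are assumptions, not the repository's actual attn_layer_patten or parameter names:

import re

attn_layer_patten = r'layers\.([0-9]+)\.'  # assumed pattern, for illustration only
keys = [
    'tok_embeddings.weight',         # no layer id -> tok
    'norm.weight',                   # no layer id -> norm
    'output.weight',                 # no layer id -> output
    'layers.0.attention.wq.weight',  # has a layer id
]

for key in keys:
    layer_id = re.findall(attn_layer_patten, key)
    if len(layer_id) == 0:
        print(key, '-> no layer id (tok/norm/output)')
    else:
        print(key, '-> layer', int(layer_id[0]))
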
10 changes: 5 additions & 5 deletions lmdeploy/turbomind/deploy/target_model/base.py
@@ -98,6 +98,7 @@ def __init__(self,
         assert self.cfg.valid
         self.to_file = to_file
         self.out_dir = out_dir
+        self.tm_params = {}
 
     @abstractmethod
     def get_config(self, cfg: TurbomindModelConfig) -> TurbomindModelConfig:
@@ -121,9 +122,6 @@ def get_config(self, cfg: TurbomindModelConfig) -> TurbomindModelConfig:
         final_cfg.update(dict(head_num=head_num, vocab_size=_vocab_size))
         return TurbomindModelConfig.from_dict(final_cfg, allow_none=True)
 
-    def set_tm_params(self, tm_params):
-        self.tm_params = tm_params
-
     def export_config(self) -> None:
         """export turbomind config."""
         if self.to_file:
@@ -143,8 +141,8 @@ def export_weight(self, param: torch.Tensor, name: str) -> None:
             tprint(name, param.shape)
             param.contiguous().cpu().numpy().tofile(
                 osp.join(self.out_dir, name))
-        elif hasattr(self, 'tm_params'):
-            tm_params = getattr(self, 'tm_params')
+        elif len(self.tm_params) > 0:
+            tm_params = self.tm_params
             weight_type = self.cfg.weight_type
             assert weight_type in ['fp16', 'fp32', 'int4']
 
@@ -162,6 +160,8 @@ def export_weight(self, param: torch.Tensor, name: str) -> None:
             for tm_tensor in tm_params[name]:
                 tm_tensor.copy_from(th_tensor)
             tm_params.pop(name)
+        else:
+            tprint('skip export', name, param.shape)
 
     def save_split(self,
                    tensor: torch.Tensor,
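
Note: together with the new self.tm_params = {} in __init__ and the removal of set_tm_params, export_weight now has three explicit branches: write the tensor to disk when to_file is set, copy it into pre-registered turbomind tensors when tm_params has been populated, and otherwise log that the weight is skipped. A simplified, self-contained sketch of the copy branch; FakeTmTensor and the weight name are made up for illustration, only the copy_from/pop pattern mirrors the diff:

import torch

class FakeTmTensor:
    """Stand-in for a turbomind-side tensor that accepts copy_from."""

    def __init__(self, shape):
        self.data = torch.empty(shape)

    def copy_from(self, th_tensor):
        self.data.copy_(th_tensor)

# one weight name may map to several device tensors (the loader appends one per GPU)
tm_params = {'layers.0.attention.wq.weight': [FakeTmTensor((128, 128))]}

name = 'layers.0.attention.wq.weight'
param = torch.randn(128, 128)

if len(tm_params) > 0 and name in tm_params:
    for tm_tensor in tm_params[name]:
        tm_tensor.copy_from(param)
    tm_params.pop(name)  # each name is consumed once, as in export_weight
else:
    print('skip export', name, param.shape)
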
13 changes: 5 additions & 8 deletions lmdeploy/turbomind/turbomind.py
@@ -156,11 +156,11 @@ def _create_weight(device_id):
         for t in threads:
             t.join()
 
-        # convert model to turbomind format if load a hf model
+        # convert model to turbomind format if loading a hf model
         if model_source in [ModelSource.HF_LMDEPLOY, ModelSource.HF_MODEL]:
-            tm_params = self.get_model_params()
+            tm_params = output_model.tm_params
+            self.get_model_params(tm_params)
             logger.warning(f'get {len(tm_params)} model params')
-            output_model.set_tm_params(tm_params)
             output_model.export()
             # load kv qparams
             self.load_kv_qparams(model_path, tm_params, **kwargs)
@@ -169,7 +169,7 @@ def _create_weight(device_id):
     def load_kv_qparams(self, model_path, tm_params, **kwargs):
         """Load kv qparams."""
         if self.config.quant_policy:
-            logger.error('loading kv_cache quant scale')
+            logger.warning('loading kv_cache quant scale')
             from lmdeploy.lite.apis.kv_qparams import main as kv_loader
             kv_sym = kwargs.get('kv_sym', False)
             kv_bits = kwargs.get('kv_bits', 8)
@@ -180,7 +180,7 @@ def load_kv_qparams(self, model_path, tm_params, **kwargs):
             if 'past_kv_scale' in key:
                 tm_params.pop(key)
 
-    def get_model_params(self):
+    def get_model_params(self, tm_params):
         """Get turbomind model params."""
 
         def _get_params(device_id, que):
@@ -198,16 +198,13 @@ def _get_params(device_id, que):
         for t in threads:
             t.join()
 
-        tm_params = {}
         for _ in range(self.gpu_count):
            tensor_map = que.get()
            for k, v in tensor_map.items():
                if k not in tm_params:
                    tm_params[k] = []
                tm_params[k].append(v)
 
-        return tm_params
-
     def create_common_from_hf(self,
                               model_source: ModelSource,
                               model_path: str,
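
Note: the changes in this file replace the old return-and-setter handshake (get_model_params() returning a fresh dict, then output_model.set_tm_params(tm_params)) with in-place population of the dict the output model already owns. A minimal, runnable illustration of that pattern, not lmdeploy code; all names below are stand-ins:

# The caller owns tm_params and the collector fills it in place, so no
# separate setter call is needed afterwards.

def get_model_params(tm_params, gpu_count=2):
    for device_id in range(gpu_count):
        # each GPU rank appends its view of the same named weight
        tm_params.setdefault('layers.0.ffn.w1.weight', []).append(f'tensor@gpu{device_id}')

class OutputModel:
    def __init__(self):
        self.tm_params = {}  # mirrors BaseOutputModel.__init__ after this commit

output_model = OutputModel()
get_model_params(output_model.tm_params)  # populate in place
print(len(output_model.tm_params), 'named params collected')
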
1 change: 0 additions & 1 deletion src/turbomind/models/llama/LlamaDecoderLayerWeight.cc
@@ -347,7 +347,6 @@ void LlamaDecoderLayerWeight<T>::loadModel(std::string dir_path, FtCudaDataType
 template<typename T>
 TensorMap LlamaDecoderLayerWeight<T>::getParams(std::string prefix)
 {
-    // TODO: support KV Cache INT8
     TensorMap output;
 
     output.insert(concat(prefix, "attention_norm.weight"),
