From 93ed0599ec688e721781fdb2d3e401c3cff598cc Mon Sep 17 00:00:00 2001 From: tpoisonooo Date: Mon, 15 Jan 2024 06:32:14 +0000 Subject: [PATCH] feat(huixiangdou/service/feature_store.py): support language option --- .gitignore | 1 + huixiangdou/service/feature_store.py | 7 +++++-- huixiangdou/service/worker.py | 2 +- tests/test_internlm2.py | 2 -- 4 files changed, 7 insertions(+), 5 deletions(-) diff --git a/.gitignore b/.gitignore index f0df30cf..789222d5 100644 --- a/.gitignore +++ b/.gitignore @@ -2,6 +2,7 @@ models/ repodir/ workdir/ write_toml.py +modeling_internlm2.py config.ini config-template.ini logs/ diff --git a/huixiangdou/service/feature_store.py b/huixiangdou/service/feature_store.py index 24e8e7e7..5671b4bf 100644 --- a/huixiangdou/service/feature_store.py +++ b/huixiangdou/service/feature_store.py @@ -25,10 +25,12 @@ class FeatureStore: def __init__(self, device: str = 'cuda', - config_path: str = 'config.ini') -> None: + config_path: str = 'config.ini', + language: str = 'zh') -> None: """Init with model device type and config.""" self.config_path = config_path self.reject_throttle = -1 + self.language = language with open(config_path, encoding='utf8') as f: config = pytoml.load(f)['feature_store'] model_path = config['model_path'] @@ -153,7 +155,8 @@ def ingress_response(self, markdown_dir: str, work_dir: str): full_text = str(p).rsplit('/_', maxsplit=1)[-1] + '\n' + f.read() if '.md' in str(p): - if not self.is_chinese_doc(full_text): + if self.language == 'zh' and not self.is_chinese_doc( + full_text): # noqa E501 continue full_texts.append(full_text) diff --git a/huixiangdou/service/worker.py b/huixiangdou/service/worker.py index 3ccd491c..af30e170 100644 --- a/huixiangdou/service/worker.py +++ b/huixiangdou/service/worker.py @@ -41,7 +41,7 @@ def __init__(self, work_dir: str, config_path: str, language: str = 'zh'): language (str, optional): Specifies the language to be used. Defaults to 'zh' (Chinese). # noqa E501 """ self.llm = ChatClient(config_path=config_path) - self.fs = FeatureStore(config_path=config_path) + self.fs = FeatureStore(config_path=config_path, language=language) self.fs.load_feature(work_dir=work_dir) self.config_path = config_path self.config = None diff --git a/tests/test_internlm2.py b/tests/test_internlm2.py index 2be761bb..3dcfe39d 100644 --- a/tests/test_internlm2.py +++ b/tests/test_internlm2.py @@ -10,8 +10,6 @@ device_map='auto', torch_dtype='auto').eval() -# 不能像某些 LLM 一样 AutoModelForCausalLM.from_pretrained(.. fp16=True) 这样写,会 Internlm2Config.__init__() 报错 - queries = ['how to install mmdeploy ?'] for query in queries: pdb.set_trace()