From 8aac808e99eb72006a75d93a538e01871cce2cb7 Mon Sep 17 00:00:00 2001 From: voidful Date: Mon, 17 Jun 2024 16:14:04 +0800 Subject: [PATCH] fix: remove redundant torch_dtype parameter in LMUtil constructor --- nlp2/lm.py | 4 ++-- setup.py | 2 +- 2 files changed, 3 insertions(+), 3 deletions(-) diff --git a/nlp2/lm.py b/nlp2/lm.py index bc1bcd3..46c50e0 100644 --- a/nlp2/lm.py +++ b/nlp2/lm.py @@ -14,7 +14,6 @@ def __init__(self, model_name="gpt2", tokenizer=None, model=None, device=None, - torch_dtype=torch.float16, device_map="auto"): if not tokenizer: self.tokenizer = AutoTokenizer.from_pretrained(model_name) @@ -25,7 +24,8 @@ def __init__(self, model_name="gpt2", else: self.device = device if not model: - self.model = AutoModelForCausalLM.from_pretrained(model_name, torch_dtype=torch_dtype, + self.model = AutoModelForCausalLM.from_pretrained(model_name, + torch_dtype=torch.float16, device_map=device_map) else: self.model = model diff --git a/setup.py b/setup.py index 1ad5d21..5cab251 100644 --- a/setup.py +++ b/setup.py @@ -2,7 +2,7 @@ setup( name='nlp2', - version='1.9.1', + version='1.9.2', description='Tool for NLP - handle file and text', long_description="Github : https://github.com/voidful/nlp2", url='https://github.com/voidful/nlp2',