model.py
import warnings
warnings.filterwarnings("ignore")

import torch
from langchain.llms import HuggingFacePipeline
from transformers import AutoTokenizer, AutoModelForCausalLM, pipeline


class LLM:
    """Loads a supported open-source LLM and wraps it in a LangChain pipeline."""

    def __init__(self, model_name):
        self.model_name = model_name
        # Hugging Face Hub IDs and sampling temperature ('T') for each supported model.
        self.model_config = {
            'llama': {
                'tokenizer': 'aleksickx/llama-7b-hf',
                'model': 'aleksickx/llama-7b-hf',
                'T': 0.1
            },
            'bloom': {
                'tokenizer': 'bigscience/bloom-7b1',
                'model': 'bigscience/bloom-7b1',
                'T': 0
            },
            'falcon': {
                'tokenizer': 'h2oai/h2ogpt-gm-oasst1-en-2048-falcon-7b-v2',
                'model': 'h2oai/h2ogpt-gm-oasst1-en-2048-falcon-7b-v2',
                'T': 0
            }
        }

    def get_model(self):
        """Load the tokenizer and 8-bit quantized model for the selected name."""
        if self.model_name not in self.model_config:
            print(f"Model '{self.model_name}' is not available!")
            return None, None, None, None, None
        config = self.model_config[self.model_name]
        tokenizer = AutoTokenizer.from_pretrained(config['tokenizer'])
        model = AutoModelForCausalLM.from_pretrained(
            config['model'],
            load_in_8bit=True,         # quantize weights to 8-bit to cut GPU memory use
            device_map='auto',         # place layers across available devices automatically
            torch_dtype=torch.float16,
            low_cpu_mem_usage=True,
        )
        max_len = 1024
        task = "text-generation"
        T = config['T']
        return tokenizer, model, max_len, task, T

    def get_pipeline(self):
        """Build a transformers text-generation pipeline and wrap it for LangChain."""
        tokenizer, model, max_len, task, T = self.get_model()
        if tokenizer is None or model is None:
            return None
        pipe = pipeline(
            task=task,
            model=model,
            tokenizer=tokenizer,
            max_length=max_len,
            temperature=T,
            top_p=0.95,
            repetition_penalty=1.15
        )
        return HuggingFacePipeline(pipeline=pipe)
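

# Illustrative usage sketch (not part of the original module): picks one of the
# keys defined in model_config above and generates text with the wrapped pipeline.
# Assumes a CUDA-capable GPU with bitsandbytes installed for the 8-bit load, and a
# LangChain version where the HuggingFacePipeline object is directly callable on a
# prompt string.
if __name__ == "__main__":
    llm = LLM("falcon")
    hf_llm = llm.get_pipeline()
    if hf_llm is not None:
        # HuggingFacePipeline implements the LangChain LLM interface, so it can be
        # invoked on a plain prompt string.
        print(hf_llm("Explain what a tokenizer does in one sentence."))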