-
Notifications
You must be signed in to change notification settings - Fork 25
/
Copy pathTextGeneration.py
64 lines (58 loc) · 1.99 KB
/
TextGeneration.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
import sys
import os
# Make the extension's bundled virtualenv importable so `autogen` (and its
# dependencies) resolve even when ComfyUI's own interpreter lacks them.
# NOTE(review): 'Lib/site-packages' is the Windows venv layout — on
# Linux/macOS a venv uses 'lib/pythonX.Y/site-packages'; confirm the
# target platform before relying on this path.
base_dir = os.path.dirname(os.path.abspath(__file__))
venv_site_packages = os.path.join(base_dir, 'venv', 'Lib', 'site-packages')
sys.path.append(venv_site_packages)
from autogen import oai
class TextGeneration:
    """ComfyUI node that runs one chat completion against a configured LLM.

    The ``LLM`` input is a dict produced by an upstream loader node:
    ``LLM['llama-cpp']`` selects the backend (truthy -> local llama-cpp
    model, falsy -> autogen ``oai.ChatCompletion`` with a config list),
    and ``LLM['LLM']`` holds the model object or config list itself.
    """

    @classmethod
    def INPUT_TYPES(cls):
        """Declare the node's required inputs for the ComfyUI graph editor."""
        return {
            "required": {
                "LLM": ("LLM",),
                # The UI supplies this flag as a string, not a boolean;
                # execute() parses it.
                "use_cache": ("STRING", {"default": "False"}),
                "timeout": ("INT", {"default": 120}),
                "system_message": ("STRING", {"default": "You are a helpful AI assistant"}),
                "Prompt": ("STRING", {
                    "multiline": True,
                    "default": "Your Prompt"
                }),
            }
        }

    RETURN_TYPES = ("TEXT",)
    FUNCTION = "execute"
    CATEGORY = "AutoGen"

    def execute(self, LLM, Prompt, system_message, use_cache, timeout):
        """Generate a reply and return it for downstream nodes.

        Args:
            LLM: backend descriptor dict (see class docstring).
            Prompt: the user message text.
            system_message: the system prompt.
            use_cache: "True"/"False" string from the UI; parsed
                case-insensitively (only the oai backend uses it).
            timeout: request timeout in seconds (oai backend only).

        Returns:
            A one-tuple containing ``{"TEXT": reply}``.
        """
        # Accept "true"/"TRUE"/" True " as well; anything else is False.
        # Backward compatible with the original exact-"True" check.
        use_cache = str(use_cache).strip().lower() == "true"
        # Both backends take the same OpenAI-style message list; build it once.
        messages = [
            {"role": "system", "content": system_message},
            {"role": "user", "content": Prompt},
        ]
        if LLM['llama-cpp']:
            # Local llama-cpp model: call it directly.
            response = LLM['LLM'].create_chat_completion(messages=messages)
        else:
            # autogen's OpenAI wrapper, driven by a config list.
            response = oai.ChatCompletion.create(
                config_list=LLM['LLM'],
                messages=messages,
                use_cache=use_cache,
                timeout=timeout,
            )
        # Both backends return an OpenAI-style completion dict.
        response = response['choices'][0]['message']['content']
        return ({"TEXT": response},)
# ComfyUI registration: maps the node's internal identifier to its class.
NODE_CLASS_MAPPINGS = {
    "TextGeneration": TextGeneration,
}
# Human-readable label shown in the ComfyUI node picker.
NODE_DISPLAY_NAME_MAPPINGS = {
    "TextGeneration": "TextGeneration"
}