Update prompts.py
RCW2000 authored Mar 2, 2024
1 parent d4e3efe commit 31887a4
Showing 1 changed file with 41 additions and 4 deletions.
45 changes: 41 additions & 4 deletions helpers/prompts.py
@@ -1,10 +1,15 @@
 import json
 #beam try
 import transformers
-from transformers import Bpipeline
+from transformers import pipeline
 import torch

-bmodel = pipeline("text-generation")
+pipe = pipeline(
+    "text-generation",
+    model="HuggingFaceH4/zephyr-7b-gemma-v0.1",
+    device_map="auto",
+    torch_dtype=torch.bfloat16,
+)

 def extractConcepts(prompt: str, metadata):
     SYS_PROMPT = (
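For context on the hunk below: when a list of role/content messages is passed to a text-generation pipeline, it is rendered through the model's chat template before generation. A minimal sketch of that rendering step, assuming the pipe object defined above and a recent transformers release; the message content is a stand-in, not the repo's actual prompt:

# Sketch: inspect the prompt string the pipeline builds from a messages list.
# Assumes `pipe` as defined in the hunk above; apply_chat_template requires
# a transformers version that ships chat templates (>= 4.34).
messages = [
    {"role": "system", "content": ""},
    {"role": "user", "content": "Using this: example text"},
]
rendered = pipe.tokenizer.apply_chat_template(
    messages,
    tokenize=False,              # return the prompt as a string, not token ids
    add_generation_prompt=True,  # append the header that cues the assistant turn
)
print(rendered)  # shows the <|im_end|> turn delimiters that stop_sequence relies on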
@@ -23,7 +28,23 @@ def extractConcepts(prompt: str, metadata):
         "]\n"
     )
     bprompt = "Using this: " + prompt + "\n" + SYS_PROMPT
-    response=bmodel(bprompt)
+    messages = [
+        {
+            "role": "system",
+            "content": "",  # Model not yet trained to follow this
+        },
+        {"role": "user", "content": bprompt},
+    ]
+    outputs = pipe(
+        messages,
+        max_new_tokens=128,
+        do_sample=True,
+        temperature=0.7,
+        top_k=50,
+        top_p=0.95,
+        stop_sequence="<|im_end|>",
+    )
+    response = outputs[0]["generated_text"][-1]["content"]
     try:
         result = json.loads(response)
         result = [dict(item, **metadata) for item in result]
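Both call sites hand the model's reply straight to json.loads, so any conversational filler around the JSON array raises a JSONDecodeError and drops into the except branch. A hypothetical guard, not part of this repo, that isolates the first bracketed array before parsing:

import json
import re

def extract_json_array(text):
    # Hypothetical helper: pull the first [...] block out of a model reply.
    match = re.search(r"\[.*\]", text, re.DOTALL)  # greedy, so outermost brackets
    if match is None:
        return None
    try:
        return json.loads(match.group(0))
    except json.JSONDecodeError:
        return None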
@@ -61,7 +82,23 @@ def graphPrompt(input: str,metadata):

     USER_PROMPT = f"context: ```{input}``` \n\n output: "
     bprompt = "Using this: " + USER_PROMPT + "\n" + SYS_PROMPT
-    response=bmodel(bprompt)
+    messages = [
+        {
+            "role": "system",
+            "content": "",  # Model not yet trained to follow this
+        },
+        {"role": "user", "content": bprompt},
+    ]
+    outputs = pipe(
+        messages,
+        max_new_tokens=128,
+        do_sample=True,
+        temperature=0.7,
+        top_k=50,
+        top_p=0.95,
+        stop_sequence="<|im_end|>",
+    )
+    response = outputs[0]["generated_text"][-1]["content"]
     print(response)
     try:
         result = json.loads(response)
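After parsing, both functions fold the caller's metadata into every item with dict(item, **metadata). A small standalone illustration of that merge; the keys and values are made up for the example:

item = {"node": "example", "weight": 4}
metadata = {"chunk_id": "abc123"}

# dict(item, **metadata) copies item first, then overlays the metadata keys,
# so metadata values win on any key collision.
merged = dict(item, **metadata)
print(merged)  # {'node': 'example', 'weight': 4, 'chunk_id': 'abc123'}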
