From 31887a4881b247538b5f465fd89e0dd1601ed4c5 Mon Sep 17 00:00:00 2001
From: RCW2000 <70165226+RCW2000@users.noreply.github.com>
Date: Sat, 2 Mar 2024 16:53:55 -0500
Subject: [PATCH] Update prompts.py

---
 helpers/prompts.py | 45 +++++++++++++++++++++++++++++++++++++++++----
 1 file changed, 41 insertions(+), 4 deletions(-)

diff --git a/helpers/prompts.py b/helpers/prompts.py
index 6f7b5c3..8ebaaf8 100644
--- a/helpers/prompts.py
+++ b/helpers/prompts.py
@@ -1,10 +1,15 @@
 import json
 #beam try
 import transformers
-from transformers import Bpipeline
+from transformers import pipeline
 import torch
 
-bmodel = pipeline("text-generation")
+pipe = pipeline(
+    "text-generation",
+    model="HuggingFaceH4/zephyr-7b-gemma-v0.1",
+    device_map="auto",
+    torch_dtype=torch.bfloat16,
+)
 
 def extractConcepts(prompt: str, metadata):
     SYS_PROMPT = (
@@ -23,7 +28,23 @@ def extractConcepts(prompt: str, metadata):
         "]\n"
     )
     bprompt="Using this: "+prompt+"/n"+SYS_PROMPT
-    response=bmodel(bprompt)
+    messages = [
+        {
+            "role": "system",
+            "content": "",  # Model not yet trained to follow system prompts
+        },
+        {"role": "user", "content": bprompt},
+    ]
+    outputs = pipe(
+        messages,
+        max_new_tokens=128,
+        do_sample=True,
+        temperature=0.7,
+        top_k=50,
+        top_p=0.95,
+        stop_sequence="<|im_end|>",
+    )
+    response = outputs[0]["generated_text"][-1]["content"]
     try:
         result = json.loads(response)
         result = [dict(item, **metadata) for item in result]
@@ -61,7 +82,23 @@ def graphPrompt(input: str,metadata):
 
     USER_PROMPT = f"context: ```{input}``` \n\n output: "
     bprompt="Using this: "+USER_PROMPT+"/n"+SYS_PROMPT
-    response=bmodel(bprompt)
+    messages = [
+        {
+            "role": "system",
+            "content": "",  # Model not yet trained to follow system prompts
+        },
+        {"role": "user", "content": bprompt},
+    ]
+    outputs = pipe(
+        messages,
+        max_new_tokens=128,
+        do_sample=True,
+        temperature=0.7,
+        top_k=50,
+        top_p=0.95,
+        stop_sequence="<|im_end|>",
+    )
+    response = outputs[0]["generated_text"][-1]["content"]
     print(response)
     try:
         result = json.loads(response)
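
Note: the pipeline call added in this patch mirrors the usage example on the
HuggingFaceH4/zephyr-7b-gemma-v0.1 model card. For reference, a minimal
standalone sketch of the same call pattern, runnable outside this repo (the
user prompt below is illustrative, not taken from the patch):

    import torch
    from transformers import pipeline

    # Build a chat-style text-generation pipeline. device_map="auto" places the
    # weights on whatever accelerator is available; bfloat16 halves memory
    # relative to float32.
    pipe = pipeline(
        "text-generation",
        model="HuggingFaceH4/zephyr-7b-gemma-v0.1",
        device_map="auto",
        torch_dtype=torch.bfloat16,
    )

    # Passing a list of role/content dicts makes the pipeline apply the model's
    # chat template before generating.
    messages = [
        {"role": "system", "content": ""},  # model not trained to follow system prompts
        {"role": "user", "content": "List three key concepts in graph databases as a JSON array."},
    ]
    outputs = pipe(
        messages,
        max_new_tokens=128,
        do_sample=True,
        temperature=0.7,
        top_k=50,
        top_p=0.95,
        stop_sequence="<|im_end|>",
    )

    # With chat input, generated_text is the message list with the assistant
    # reply appended last; take its content.
    print(outputs[0]["generated_text"][-1]["content"])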