From f50b479a038ea5feee54c8f7aad8129239e26615 Mon Sep 17 00:00:00 2001
From: Geun Han Chung
Date: Sun, 24 Nov 2024 10:24:20 -0500
Subject: [PATCH] bug fix

---
 backend/generate_answer.py | 5 +++--
 1 file changed, 3 insertions(+), 2 deletions(-)

diff --git a/backend/generate_answer.py b/backend/generate_answer.py
index 3621e94..bb4b172 100644
--- a/backend/generate_answer.py
+++ b/backend/generate_answer.py
@@ -11,7 +11,7 @@ def load_base_model():
     # Load tokenizer
     tokenizer = AutoTokenizer.from_pretrained(Config.MODEL_NAME,
                                               token=Config.HUGGINGFACE_ACCESS_TOKEN,
-                                              max_new_tokens=300).to(device)
+                                              max_new_tokens=300)
 
     # Load model in 8-bit to reduce memory usage
     base_model = AutoModelForCausalLM.from_pretrained(Config.MODEL_NAME,
@@ -45,7 +45,8 @@ def get_device():
     return "cuda" if torch.cuda.is_available() else "cpu"
 
 def generate_initial_note(page_content, model, tokenizer):
-    inputs = tokenizer(f"Generate notes for the given content: {page_content}", return_tensors="pt")
+    device = "cuda" if torch.cuda.is_available() else "cpu"
+    inputs = tokenizer(f"Generate notes for the given content: {page_content}", return_tensors="pt").to(device)
     outputs = model.generate(
         **inputs,
         max_length=2048,
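
Below is a minimal, self-contained sketch of the device-placement pattern this patch applies. The model name and prompt are placeholders, and the project's Config, Hugging Face access token, and 8-bit quantized loading are left out. The points the patch makes: a tokenizer is a plain Python object with no .to() method, input tensors must be moved to the model's device before generate(), and decode() returns an ordinary string that needs no device transfer.

    import torch
    from transformers import AutoModelForCausalLM, AutoTokenizer

    MODEL_NAME = "gpt2"  # placeholder; the project loads Config.MODEL_NAME

    device = "cuda" if torch.cuda.is_available() else "cpu"

    # A tokenizer stays on the CPU: it has no .to(device) method, and
    # generation limits such as max_new_tokens belong to generate(),
    # not to from_pretrained().
    tokenizer = AutoTokenizer.from_pretrained(MODEL_NAME)
    model = AutoModelForCausalLM.from_pretrained(MODEL_NAME).to(device)

    # Input tensors must live on the same device as the model's weights,
    # otherwise generate() fails with a device-mismatch error on CUDA.
    inputs = tokenizer("Generate notes for the given content: ...",
                       return_tensors="pt").to(device)
    outputs = model.generate(**inputs, max_new_tokens=300)

    # decode() returns a str; calling .to(device) on it would raise
    # AttributeError, so the decoded pieces are joined as plain strings.
    final_output = "".join(tokenizer.decode(o, skip_special_tokens=True)
                           for o in outputs)
    print(final_output)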