Skip to content

Commit

Permalink
Bug fix: move `.to(device)` off the `AutoTokenizer.from_pretrained(...)` result (tokenizers have no `.to()` method) and apply it to the tokenized inputs in `generate_initial_note` instead
Browse files Browse the repository at this point in the history
  • Loading branch information
Geun Han Chung authored and Geun Han Chung committed Nov 24, 2024
1 parent 5e75ac5 commit f50b479
Showing 1 changed file with 4 additions and 3 deletions.
7 changes: 4 additions & 3 deletions backend/generate_answer.py
Original file line number Diff line number Diff line change
Expand Up @@ -11,7 +11,7 @@ def load_base_model():
# Load tokenizer
tokenizer = AutoTokenizer.from_pretrained(Config.MODEL_NAME,
token=Config.HUGGINGFACE_ACCESS_TOKEN,
max_new_tokens=300).to(device)
max_new_tokens=300)

# Load model in 8-bit to reduce memory usage
base_model = AutoModelForCausalLM.from_pretrained(Config.MODEL_NAME,
Expand Down Expand Up @@ -45,7 +45,8 @@ def get_device():
return "cuda" if torch.cuda.is_available() else "cpu"

def generate_initial_note(page_content, model, tokenizer):
inputs = tokenizer(f"Generate notes for the given content: {page_content}", return_tensors="pt")
device = "cuda" if torch.cuda.is_available() else "cpu"
inputs = tokenizer(f"Generate notes for the given content: {page_content}", return_tensors="pt").to(device)
outputs = model.generate(
**inputs,
max_length=2048,
Expand All @@ -55,7 +56,7 @@ def generate_initial_note(page_content, model, tokenizer):
)
final_output = ""
for output in outputs:
final_output += tokenizer.decode(output, skip_special_tokens=True)
final_output += tokenizer.decode(output, skip_special_tokens=True).to(device)
return final_output

def generate_note(page_content, note_content, model, tokenizer):
Expand Down

0 comments on commit f50b479

Please sign in to comment.