Skip to content

Commit

Permalink
Update generate_answer.py
Browse files Browse the repository at this point in the history
  • Loading branch information
AmirAgassi committed Nov 24, 2024
1 parent b37fd26 commit 9fc8db1
Showing 1 changed file with 16 additions and 2 deletions.
18 changes: 16 additions & 2 deletions backend/generate_answer.py
Original file line number Diff line number Diff line change
Expand Up @@ -7,14 +7,15 @@
from rl_model import RLModel

def load_base_model():
    """Load the base causal language model and its tokenizer.

    Both are fetched by ``Config.MODEL_NAME`` and authenticated with
    ``Config.HUGGINGFACE_ACCESS_TOKEN``. The model is placed on CUDA when
    it is available and on the CPU otherwise.

    Returns:
        A ``(base_model, tokenizer)`` tuple.
    """
    device = "cuda" if torch.cuda.is_available() else "cpu"

    # Load tokenizer. ``max_new_tokens`` is a generation-time option, not a
    # tokenizer option, so it is not passed here; set it on ``generate()``.
    tokenizer = AutoTokenizer.from_pretrained(
        Config.MODEL_NAME,
        token=Config.HUGGINGFACE_ACCESS_TOKEN,
    )

    # Load the model and move it to the selected device. No quantization
    # arguments are passed, so the weights are NOT loaded in 8-bit.
    base_model = AutoModelForCausalLM.from_pretrained(
        Config.MODEL_NAME,
        token=Config.HUGGINGFACE_ACCESS_TOKEN,
    ).to(device)

    return base_model, tokenizer

Expand Down Expand Up @@ -58,14 +59,27 @@ def generate_initial_note(page_content, model, tokenizer):
return final_output

def generate_note(page_content, note_content, model, tokenizer):
    """Generate a replacement note for ``page_content``.

    Args:
        page_content: Source text the note should be written for.
        note_content: The previous note that was rejected.
        model: Either an ``RLModel`` or a model exposing ``generate()``.
        tokenizer: Tokenizer used to encode the prompt and decode the output.

    Returns:
        The first generated sequence decoded to a string, with special
        tokens skipped.
    """
    # get device and ensure model is on it
    device = "cuda" if torch.cuda.is_available() else "cpu"
    model = model.to(device)

    if isinstance(model, RLModel):
        state = model.get_state(page_content, note_content)
        outputs = model.act(state)
        return tokenizer.decode(outputs[0], skip_special_tokens=True)

    # fallback to original implementation
    inputs = tokenizer(
        f"I did not like this note: {note_content}. Generate new notes for the given content: {page_content}",
        return_tensors="pt",
    )
    # move inputs to same device as model before generating, otherwise the
    # tensors and the weights can end up on different devices
    inputs = {k: v.to(device) for k, v in inputs.items()}

    # NOTE(review): max_length counts prompt tokens as well as generated
    # ones, so a very long page_content leaves little or no room for output;
    # consider max_new_tokens instead -- confirm the intended budget.
    outputs = model.generate(
        **inputs,
        max_length=2048,
        num_return_sequences=1,
        pad_token_id=tokenizer.eos_token_id,
        # temperature only takes effect when sampling is enabled
        do_sample=True,
        temperature=0.7,
    )
    # NOTE(review): for a causal LM the returned sequence presumably starts
    # with the prompt tokens, so the decoded text may include the prompt --
    # confirm whether callers expect that.
    return tokenizer.decode(outputs[0], skip_special_tokens=True)

def generate_quiz_review(origin_content, wrong_questions, model, tokenizer):
Expand Down

0 comments on commit 9fc8db1

Please sign in to comment.