
Commit

fixes to LLM generation code
eisenzopf committed May 14, 2024
1 parent 9aa673b commit 5345d7c
Showing 2 changed files with 4 additions and 7 deletions.
9 changes: 3 additions & 6 deletions llm_eval/handler.py

@@ -27,12 +27,9 @@ def generate_output(self, text):
{"role": "user", "content": text },
]
prompt = self.tokenizer.apply_chat_template(messages, tokenize=False, add_generation_prompt=True)
inputs = self.tokenizer(prompt, return_tensors="pt", padding=True, truncation=True, max_length=self.max_length)
inputs = {key: value.to(self.device) for key, value in inputs.items()}
with torch.no_grad():
outputs = self.model.generate(**inputs, max_new_tokens=self.max_new_tokens, do_sample=True, temperature=self.temperature, top_p=self.top_p)
responses = [self.tokenizer.decode(output, skip_special_tokens=True) for output in outputs]
responses = ' '.join(responses)
inputs = self.tokenizer(prompt, return_tensors="pt", padding=True, truncation=True, max_length=self.max_length).to(self.device)
outputs = self.model.generate(**inputs, max_new_tokens=self.max_new_tokens, do_sample=True, temperature=self.temperature, top_p=self.top_p)
responses = self.tokenizer.decode(outputs[0], skip_special_tokens=True)
return prompt, responses

     def load_dataset(self, dataset):
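
For readers without the full file, here is a minimal, self-contained sketch of the revised generation path. The model name, max_length, and sampling values are illustrative placeholders, not values from this repository; only the tokenize / generate / decode flow mirrors the committed code.

    import torch
    from transformers import AutoModelForCausalLM, AutoTokenizer

    device = "cuda" if torch.cuda.is_available() else "cpu"
    model_name = "mistralai/Mistral-7B-Instruct-v0.2"  # placeholder chat model

    tokenizer = AutoTokenizer.from_pretrained(model_name)
    model = AutoModelForCausalLM.from_pretrained(model_name).to(device)

    messages = [{"role": "user", "content": "Say hello."}]
    prompt = tokenizer.apply_chat_template(messages, tokenize=False, add_generation_prompt=True)

    # tokenizer(...) returns a BatchEncoding; calling .to(device) on it moves
    # every tensor at once, replacing the old per-key dict comprehension.
    inputs = tokenizer(prompt, return_tensors="pt", padding=True, truncation=True, max_length=1024).to(device)

    outputs = model.generate(**inputs, max_new_tokens=128, do_sample=True, temperature=0.7, top_p=0.9)

    # Decode only the first returned sequence; the old code decoded every row
    # of the batch and joined them with spaces, which concatenates unrelated
    # sequences whenever generate returns more than one.
    response = tokenizer.decode(outputs[0], skip_special_tokens=True)
    print(response)

Dropping the explicit torch.no_grad() block is safe here: model.generate already runs with gradients disabled internally, so the wrapper was redundant.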
2 changes: 1 addition & 1 deletion pyproject.toml

@@ -6,7 +6,7 @@ build-backend = "setuptools.build_meta"

 [project]
 name = "llm-eval"
-version = "0.3.16"
+version = "0.3.17"
 authors = [
     {name = "Jonathan Eisenzopf", email = "[email protected]"},
 ]
