
Commit

fixed generate.py
little51 committed Jul 28, 2023
1 parent 5210d0d commit 9834472
Showing 2 changed files with 97 additions and 88 deletions.
3 changes: 1 addition & 2 deletions README.md
@@ -80,8 +80,7 @@ tail -f train.log
```bash
CUDA_VISIBLE_DEVICES=0 python generate.py \
--base_model './models/daryl149/llama-2-7b-chat-hf' \
--lora_weights 'output/checkpoint-2800' \
--lora_weights 'output/checkpoint-2000' \
--load_8bit # without this parameter, 4-bit quantization is used
```
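
For context: the `--load_8bit` switch selects between the two bitsandbytes quantization configs that appear in the generate.py diff below (8-bit when the flag is given, 4-bit NF4 otherwise, per the comment above). A minimal sketch of that selection, reusing the `BitsAndBytesConfig` parameters visible in the diff; the model path and the `load_8bit` variable are illustrative and not part of the commit:

```python
import torch
from transformers import BitsAndBytesConfig, LlamaForCausalLM

# 8-bit config (used when --load_8bit is passed); CPU offload enabled as in the diff
bnb_config_8bit = BitsAndBytesConfig(
    load_in_8bit=True,
    llm_int8_enable_fp32_cpu_offload=True,
)

# 4-bit NF4 config (used when the flag is omitted), matching the parameters in the diff
bnb_config_4bit = BitsAndBytesConfig(
    load_in_4bit=True,
    bnb_4bit_use_double_quant=True,
    bnb_4bit_quant_type="nf4",
    bnb_4bit_compute_dtype=torch.bfloat16,
    llm_int8_enable_fp32_cpu_offload=True,
)

load_8bit = False  # illustrative; generate.py takes this from the --load_8bit flag
model = LlamaForCausalLM.from_pretrained(
    "./models/daryl149/llama-2-7b-chat-hf",  # path from the README example
    # per the README comment: 8-bit when --load_8bit is given, 4-bit NF4 otherwise
    quantization_config=bnb_config_8bit if load_8bit else bnb_config_4bit,
    torch_dtype=torch.float16,
    device_map="auto",
)
```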


182 changes: 96 additions & 86 deletions generate.py
@@ -1,92 +1,63 @@
import os
import sys
import argparse

import argparse
import torch
import transformers
from peft import PeftModel
from transformers import GenerationConfig, LlamaForCausalLM, LlamaTokenizer,BitsAndBytesConfig
from transformers import GenerationConfig, LlamaForCausalLM, LlamaTokenizer

from utils.prompter import Prompter

if torch.cuda.is_available():
device = "cuda"
else:
device = "cpu"

generation_config = dict(
temperature=0.2,
top_k=40,
top_p=0.9,
do_sample=True,
num_beams=1,
repetition_penalty=1.3,
max_new_tokens=100
)

def generate() :
# parse args
parser = argparse.ArgumentParser()
parser.add_argument('--base_model', default=None, type=str, required=True)
parser.add_argument('--lora_weights', default=None, type=str,
help="If None, perform inference on the base model")
parser.add_argument('--load_8bit', action='store_true',
help='only use CPU for inference')
args = parser.parse_args()
if args.load_8bit is None:
load_8bit = False
else:
load_8bit = args.load_8bit
if args.base_model is None:
base_model = "./model/llama-7b"
else:
base_model = args.base_model
if args.lora_weights is None:
lora_weights = "./model/llama-peft"
else:
lora_weights = args.lora_weights

bnb_config_4bit = BitsAndBytesConfig(
load_in_4bit=True,
bnb_4bit_use_double_quant=True,
bnb_4bit_quant_type="nf4",
bnb_4bit_compute_dtype=torch.bfloat16,
llm_int8_enable_fp32_cpu_offload=True
)

bnb_config_8bit = BitsAndBytesConfig(
load_in_8bit=True,
llm_int8_enable_fp32_cpu_offload=True
)

device_map_cpu = {
"transformer.word_embeddings": "cpu",
"transformer.word_embeddings_layernorm": "cpu",
"lm_head": "cpu",
"transformer.h": "cpu",
"transformer.ln_f": "cpu",
"model.embed_tokens": "cpu",
"model.layers":"cpu",
"model.norm":"cpu"
}

device_map_cpu = {"": "cpu"}

device_map_gpu = "auto"

# load model
try:
if torch.backends.mps.is_available():
device = "mps"
except: # noqa: E722
pass


def main(
load_8bit: bool = False,
base_model: str = "",
lora_weights: str = "tloen/alpaca-lora-7b",
prompt_template: str = ""
):
base_model = base_model or os.environ.get("BASE_MODEL", "")
assert (
base_model
), "Please specify a --base_model, e.g. --base_model='huggyllama/llama-7b'"

prompter = Prompter(prompt_template)
tokenizer = LlamaTokenizer.from_pretrained(base_model)
if device == "cuda":
model = LlamaForCausalLM.from_pretrained(
base_model,
load_in_8bit=load_8bit,
quantization_config=bnb_config_4bit if load_8bit else bnb_config_8bit,
torch_dtype=torch.float16,
device_map=device_map_gpu if load_8bit else device_map_cpu
device_map="auto",
)
model = PeftModel.from_pretrained(
model,
lora_weights,
torch_dtype=torch.float16,
)
elif device == "mps":
model = LlamaForCausalLM.from_pretrained(
base_model,
device_map={"": device},
torch_dtype=torch.float16,
)
model = PeftModel.from_pretrained(
model,
lora_weights,
device_map={"": device},
torch_dtype=torch.float16,
)
else:
model = LlamaForCausalLM.from_pretrained(
base_model, device_map={"": device}, low_cpu_mem_usage=True
@@ -97,35 +68,74 @@ def generate() :
device_map={"": device},
)

model.config.pad_token_id = tokenizer.pad_token_id = 0
# unwind broken decapoda-research config
model.config.pad_token_id = tokenizer.pad_token_id = 0 # unk
model.config.bos_token_id = 1
model.config.eos_token_id = 2

if not load_8bit:
model.half()
model.half() # seems to fix bugs for some users.

model.eval()
if torch.__version__ >= "2" and sys.platform != "win32":
model = torch.compile(model)

# generate
def evaluate(
instruction,
input=None,
temperature=0.1,
top_p=0.75,
top_k=40,
num_beams=4,
max_new_tokens=128,
stream_output=False,
**kwargs,
):
prompt = prompter.generate_prompt(instruction, input)
inputs = tokenizer(prompt, return_tensors="pt")
input_ids = inputs["input_ids"].to(device)
generation_config = GenerationConfig(
temperature=temperature,
top_p=top_p,
top_k=top_k,
num_beams=num_beams,
**kwargs,
)

generate_params = {
"input_ids": input_ids,
"generation_config": generation_config,
"return_dict_in_generate": True,
"output_scores": True,
"max_new_tokens": max_new_tokens,
}

with torch.no_grad():
generation_output = model.generate(
input_ids=input_ids,
generation_config=generation_config,
return_dict_in_generate=True,
output_scores=True,
max_new_tokens=max_new_tokens,
)
s = generation_output.sequences[0]
output = tokenizer.decode(s)
return prompter.get_response(output)

while True:
input_text = input("Input:")
if len(input_text.strip()) == 0:
instruction = input("Input:")
if len(instruction.strip()) == 0:
break
inputs = tokenizer(input_text, return_tensors="pt")
generation_output = model.generate(
input_ids=inputs["input_ids"].to(device),
attention_mask=inputs['attention_mask'].to(device),
eos_token_id=tokenizer.eos_token_id,
pad_token_id=tokenizer.pad_token_id,
**generation_config
)
s = generation_output[0]
response = tokenizer.decode(s, skip_special_tokens=True)
print("Response: ", response)
print("\n")

if __name__ == '__main__':
with torch.autocast("cuda"):
generate()
print("Response:", evaluate(instruction))


if __name__ == "__main__":
# parse args
parser = argparse.ArgumentParser()
parser.add_argument('--base_model', default=None, type=str, required=True)
parser.add_argument('--lora_weights', default=None, type=str,
help="If None, perform inference on the base model")
parser.add_argument('--load_8bit', action='store_true',
help='only use CPU for inference')
args = parser.parse_args()
main(args.load_8bit, args.base_model, args.lora_weights, "")
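
Both variants of generate.py in this diff share the same LoRA inference flow: load the base Llama model, attach the fine-tuned adapter with `PeftModel.from_pretrained`, then sample with `model.generate`. A condensed sketch of that flow, assuming the example paths from the README above and an illustrative prompt; the sampling values mirror those visible in the diff:

```python
import torch
from peft import PeftModel
from transformers import GenerationConfig, LlamaForCausalLM, LlamaTokenizer

base_model = "./models/daryl149/llama-2-7b-chat-hf"  # from the README example
lora_weights = "output/checkpoint-2000"              # from the README example

tokenizer = LlamaTokenizer.from_pretrained(base_model)
model = LlamaForCausalLM.from_pretrained(
    base_model, torch_dtype=torch.float16, device_map="auto"
)
# Attach the fine-tuned LoRA adapter on top of the frozen base weights.
model = PeftModel.from_pretrained(model, lora_weights, torch_dtype=torch.float16)
model.eval()

# Illustrative prompt; generate.py builds prompts with utils.prompter.Prompter instead.
inputs = tokenizer("Hello, who are you?", return_tensors="pt").to(model.device)
with torch.no_grad():
    output_ids = model.generate(
        **inputs,
        generation_config=GenerationConfig(
            temperature=0.2, top_p=0.9, top_k=40, do_sample=True
        ),
        max_new_tokens=100,
    )
print(tokenizer.decode(output_ids[0], skip_special_tokens=True))
```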
