From 788da826ee28bf051c1b119f7dec509fd0f489e4 Mon Sep 17 00:00:00 2001 From: Jingwei Yi Date: Tue, 19 Sep 2023 06:43:34 +0000 Subject: [PATCH 1/7] modify gitignore and fix the bug when run humaneval --- .gitignore | 4 ++++ modeling.py | 6 ++++-- 2 files changed, 8 insertions(+), 2 deletions(-) diff --git a/.gitignore b/.gitignore index a26ef87..52961fe 100644 --- a/.gitignore +++ b/.gitignore @@ -1,2 +1,6 @@ /.idea/ *.jsonl +*.csv + +__pycache__/ +data/ diff --git a/modeling.py b/modeling.py index 53bad94..9e3ae16 100644 --- a/modeling.py +++ b/modeling.py @@ -153,10 +153,12 @@ def load(self): def run(self, prompt: str, **kwargs) -> str: self.load() inputs = self.tokenizer(prompt, return_tensors="pt").to(self.device) + do_sample = kwargs.pop("do_sample", self.do_sample) + max_output_length = kwargs.pop("max_output_length", self.max_output_length) outputs = self.model.generate( **inputs, - max_length=self.max_output_length, - do_sample=self.do_sample, + max_length=max_output_length, + do_sample=do_sample, **kwargs, ) return self.tokenizer.decode(outputs[0], skip_special_tokens=True) From 3954b6f161c38dde8c7b39d4f52114652bef877e Mon Sep 17 00:00:00 2001 From: Jingwei Yi Date: Thu, 21 Sep 2023 09:01:02 +0000 Subject: [PATCH 2/7] enable multi-GPU support --- modeling.py | 18 ++++++++++++------ 1 file changed, 12 insertions(+), 6 deletions(-) diff --git a/modeling.py b/modeling.py index 53bad94..c08a402 100644 --- a/modeling.py +++ b/modeling.py @@ -141,12 +141,14 @@ def load(self): args = {} if self.load_8bit: args.update(device_map="auto", load_in_8bit=True) + else: + args.update(device_map="auto", torch_dtype=torch.float16) self.model = AutoModelForSeq2SeqLM.from_pretrained(self.model_path, **args) if self.lora_path: self.model = PeftModel.from_pretrained(self.model, self.lora_path) self.model.eval() - if not self.load_8bit: - self.model.to(self.device) + # if not self.load_8bit: + # self.model.to(self.device) if self.tokenizer is None: self.tokenizer = AutoTokenizer.from_pretrained(self.model_path) @@ -190,12 +192,14 @@ def load(self): args = {} if self.load_8bit: args.update(device_map="auto", load_in_8bit=True) + else: + args.update(device_map="auto", torch_dtype=torch.float16) self.model = AutoModelForCausalLM.from_pretrained( self.model_path, trust_remote_code=True, **args ) self.model.eval() - if not self.load_8bit: - self.model.to(self.device) + # if not self.load_8bit: + # self.model.to(self.device) if self.tokenizer is None: self.tokenizer = AutoTokenizer.from_pretrained( self.model_path, trust_remote_code=True @@ -247,12 +251,14 @@ def load(self): args = {} if self.load_8bit: args.update(device_map="auto", load_in_8bit=True) + else: + args.update(device_map="auto", torch_dtype=torch.float16) self.model = LlamaForCausalLM.from_pretrained(self.model_path, **args) if self.lora_path: self.model = PeftModel.from_pretrained(self.model, self.lora_path) self.model.eval() - if not self.load_8bit: - self.model.to(self.device) + # if not self.load_8bit: + # self.model.to(self.device) def run(self, prompt: str, **kwargs) -> str: if self.use_template: From ac5654215f6ae16d31aab36fd0a9f6f31798eab5 Mon Sep 17 00:00:00 2001 From: Jingwei Yi Date: Sat, 23 Sep 2023 03:38:54 +0000 Subject: [PATCH 3/7] add vllm --- bbh.py | 12 +++++----- mmlu.py | 14 +++++++----- modeling.py | 65 +++++++++++++++++++++++++++++++++++++++++++++++++++++ 3 files changed, 81 insertions(+), 10 deletions(-) diff --git a/bbh.py b/bbh.py index 6160061..642ec0b 100644 --- a/bbh.py +++ b/bbh.py @@ -51,6 +51,7 @@ def 
evaluate(model: EvalModel, data: BBHData, ntrain: int) -> dict: data_test = BBHData(samples=data.samples[ntrain:]) is_correct = [] + prompts = [] for i in range(len(data_test.samples)): # get prompt and make sure it fits k = int(ntrain) @@ -63,11 +64,12 @@ def evaluate(model: EvalModel, data: BBHData, ntrain: int) -> dict: train_prompt = gen_prompt(data_train, k) prompt = train_prompt + prompt_end - label = data_test.samples[i].target - pred = model.run(prompt) - is_correct.append(pred.strip().startswith(label)) - if i == 0: - print(dict(prompt=prompt, label=label, pred=pred)) + prompts.append(prompt) + + preds = model.run(prompts) + is_correct.extend([pred.strip().startswith(label) for pred, label in zip(preds, data_test.target)]) + if i == 0: + print(dict(prompt=prompt, label=label, pred=pred)) return dict(score=sum(is_correct) / len(is_correct)) diff --git a/mmlu.py b/mmlu.py index 7adcd7b..de53e6e 100644 --- a/mmlu.py +++ b/mmlu.py @@ -137,6 +137,8 @@ def evaluate(args, subject, model: EvalModel, dev_df, test_df): cors = [] all_probs = [] + labels = [] + prompts = [] for i in range(test_df.shape[0]): # get prompt and make sure it fits k = args.ntrain @@ -150,11 +152,13 @@ def evaluate(args, subject, model: EvalModel, dev_df, test_df): prompt = train_prompt + prompt_end label = test_df.iloc[i, test_df.shape[1] - 1] - pred = model.run(prompt) - probs = [0 for _ in get_choices()] - cor = pred.strip().startswith(label) - cors.append(cor) - all_probs.append(probs) + prompts.append(prompt) + labels.append(label) + + preds = model.run(prompts) + probs = [0 for _ in get_choices()] + cors = [pred.strip().startswith(label) for pred, label in zip(preds, labels)] + all_probs.extend([probs for _ in preds]) acc = np.mean(cors) cors = np.array(cors) diff --git a/modeling.py b/modeling.py index f99d660..aeaf40f 100644 --- a/modeling.py +++ b/modeling.py @@ -27,8 +27,10 @@ AutoModel, LlamaConfig, ) +from vllm import LLM, SamplingParams import quant +from fastchat.model import get_conversation_template class EvalModel(BaseModel, arbitrary_types_allowed=True): @@ -125,6 +127,68 @@ def handler(signum, frame): else: time.sleep(3) return "Z" + + +class vllmModel(EvalModel): + model_path: str + # template_name: str + trust_remote_code: bool = True + model: Optional[LLM] + tokenizer: Optional[PreTrainedTokenizer] + load_8bit: bool = False + temperature: float = 0.0 + tensor_parallel_size: int = 1 + + def load(self): + if self.model is None: + args = {} + if self.load_8bit: + args.update(device_map="auto", load_in_8bit=True) + else: + args.update(device_map="auto", torch_dtype=torch.float16) + self.model = LLM( + model=self.model_path, + trust_remote_code=self.trust_remote_code, + tensor_parallel_size=self.tensor_parallel_size + ) + if self.tokenizer is None: + self.tokenizer = AutoTokenizer.from_pretrained( + self.model_path, + trust_remote_code=self.trust_remote_code + ) + + def count_text_length(self, text: str) -> int: + self.load() + return len(self.tokenizer(text).input_ids) + + def run(self, prompts: str, **kwargs) -> str: + self.load() + # new_prompts = [] + # for prompt in prompts: + # conv_template = get_conversation_template(self.template_name) + # conv_template.set_system_message( + # "You are an AI assistant. Please provide helpful, detailed, and polite answers to the user's questions." 
+ # ) + # conv_template.append_message(conv_template.roles[0], prompt) + # conv_template.append_message(conv_template.roles[1], None) + # new_prompts.append(conv_template.get_prompt()) + + do_sample = kwargs.pop("do_sample", True) + max_output_length = kwargs.pop("max_output_length", self.max_output_length) + temperature = 0 if not do_sample else kwargs.pop("temperature", self.temperature) + + sampling_params = SamplingParams( + temperature=temperature, + max_tokens=max_output_length, + **kwargs + ) + outputs = self.model.generate( + prompts, sampling_params + ) + return [output.outputs[0].text for output in outputs] + + def get_choice(self, text: str, **kwargs) -> Tuple[float, float]: + raise NotImplementedError class SeqToSeqModel(EvalModel): @@ -506,6 +570,7 @@ def select_model(model_name: str, **kwargs) -> EvalModel: openai=OpenAIModel, rwkv=RWKVModel, gptq=GPTQModel, + vllm=vllmModel ) model_class = model_map.get(model_name) if model_class is None: From c6653b7782985a8f242f41523bc3c148d4179c09 Mon Sep 17 00:00:00 2001 From: Jingwei Yi Date: Sun, 24 Sep 2023 07:58:11 +0000 Subject: [PATCH 4/7] add rslt path --- main.py | 14 ++++++++++++++ 1 file changed, 14 insertions(+) diff --git a/main.py b/main.py index ff8773c..0afd13a 100644 --- a/main.py +++ b/main.py @@ -1,5 +1,8 @@ from fire import Fire +from pathlib import Path +import json + import bbh import crass import drop @@ -9,6 +12,11 @@ def main(task_name: str, **kwargs): + rslt_path = kwargs.pop('rslt_path', None) + if rslt_path is not None and Path(rslt_path).exists(): + print(f"Already have file in {rslt_path}. Exist.") + exit(0) + task_map = dict( mmlu=mmlu.main, bbh=bbh.main, @@ -46,6 +54,12 @@ def main(task_name: str, **kwargs): results = {name: round(score * 100, 2) for name, score in results.items()} print(results) + + if rslt_path is not None: + Path(rslt_path).parent.mkdir(exist_ok=True, parents=True) + with open(rslt_path, 'w') as f: + json.dump(results, f) + return results From ae62d080e916a96624d7cde459f589061b847950 Mon Sep 17 00:00:00 2001 From: Jingwei Yi Date: Sat, 7 Oct 2023 09:22:26 +0000 Subject: [PATCH 5/7] add batch size for causual model --- .vscode/settings.json | 6 ++++ modeling.py | 65 ++++++++++++++++++++++++++----------------- 2 files changed, 45 insertions(+), 26 deletions(-) create mode 100644 .vscode/settings.json diff --git a/.vscode/settings.json b/.vscode/settings.json new file mode 100644 index 0000000..d99f2f3 --- /dev/null +++ b/.vscode/settings.json @@ -0,0 +1,6 @@ +{ + "[python]": { + "editor.defaultFormatter": "ms-python.black-formatter" + }, + "python.formatting.provider": "none" +} \ No newline at end of file diff --git a/modeling.py b/modeling.py index aeaf40f..93c15ea 100644 --- a/modeling.py +++ b/modeling.py @@ -127,7 +127,7 @@ def handler(signum, frame): else: time.sleep(3) return "Z" - + class vllmModel(EvalModel): model_path: str @@ -149,12 +149,11 @@ def load(self): self.model = LLM( model=self.model_path, trust_remote_code=self.trust_remote_code, - tensor_parallel_size=self.tensor_parallel_size + tensor_parallel_size=self.tensor_parallel_size, ) if self.tokenizer is None: self.tokenizer = AutoTokenizer.from_pretrained( - self.model_path, - trust_remote_code=self.trust_remote_code + self.model_path, trust_remote_code=self.trust_remote_code ) def count_text_length(self, text: str) -> int: @@ -175,18 +174,16 @@ def run(self, prompts: str, **kwargs) -> str: do_sample = kwargs.pop("do_sample", True) max_output_length = kwargs.pop("max_output_length", self.max_output_length) - temperature 
= 0 if not do_sample else kwargs.pop("temperature", self.temperature) + temperature = ( + 0 if not do_sample else kwargs.pop("temperature", self.temperature) + ) sampling_params = SamplingParams( - temperature=temperature, - max_tokens=max_output_length, - **kwargs - ) - outputs = self.model.generate( - prompts, sampling_params + temperature=temperature, max_tokens=max_output_length, **kwargs ) + outputs = self.model.generate(prompts, sampling_params) return [output.outputs[0].text for output in outputs] - + def get_choice(self, text: str, **kwargs) -> Tuple[float, float]: raise NotImplementedError @@ -253,6 +250,8 @@ def get_choice(self, text: str, **kwargs) -> Tuple[float, float]: class CausalModel(SeqToSeqModel): + batch_size: int = 1 + def load(self): if self.model is None: args = {} @@ -268,24 +267,38 @@ def load(self): # self.model.to(self.device) if self.tokenizer is None: self.tokenizer = AutoTokenizer.from_pretrained( - self.model_path, trust_remote_code=True + self.model_path, trust_remote_code=True, padding_side="left" ) + if self.tokenizer.pad_token is None: + self.tokenizer.pad_token = self.tokenizer.eos_token def run(self, prompt: str, **kwargs) -> str: self.load() - inputs = self.tokenizer(prompt, return_tensors="pt").to(self.device) - if "RWForCausalLM" in str(type(self.model)): - inputs.pop("token_type_ids") # Not used by Falcon model + all_outputs = [] - outputs = self.model.generate( - **inputs, - max_new_tokens=self.max_output_length, - pad_token_id=self.tokenizer.eos_token_id, # Avoid pad token warning - do_sample=self.do_sample, - **kwargs, - ) - batch_size, length = inputs.input_ids.shape - return self.tokenizer.decode(outputs[0, length:], skip_special_tokens=True) + for i in range(0, len(prompt), self.batch_size): + batch_prompt = prompt[i : i + self.batch_size] + inputs = self.tokenizer(batch_prompt, return_tensors="pt", padding=True).to( + self.device + ) + if "RWForCausalLM" in str(type(self.model)): + inputs.pop("token_type_ids") # Not used by Falcon model + + outputs = self.model.generate( + **inputs, + max_new_tokens=self.max_output_length, + pad_token_id=self.tokenizer.eos_token_id, # Avoid pad token warning + do_sample=self.do_sample, + **kwargs, + ) + batch_size, length = inputs.input_ids.shape + all_outputs.extend( + self.tokenizer.batch_decode( + outputs[:, length:], skip_special_tokens=True + ) + ) + + return all_outputs def get_choice(self, text: str, **kwargs) -> Tuple[float, float]: self.load() @@ -570,7 +583,7 @@ def select_model(model_name: str, **kwargs) -> EvalModel: openai=OpenAIModel, rwkv=RWKVModel, gptq=GPTQModel, - vllm=vllmModel + vllm=vllmModel, ) model_class = model_map.get(model_name) if model_class is None: From 92190d0ff424e67b61fa78312d90ea459c0c864b Mon Sep 17 00:00:00 2001 From: Jingwei Yi Date: Sat, 7 Oct 2023 15:58:19 +0000 Subject: [PATCH 6/7] fix other dataset --- bbh.py | 9 ++++++--- crass.py | 14 ++++++++++---- drop.py | 16 ++++++++++++---- human_eval/main.py | 38 ++++++++++++++++++++++++-------------- modeling.py | 7 +++++-- 5 files changed, 57 insertions(+), 27 deletions(-) diff --git a/bbh.py b/bbh.py index 642ec0b..098e7f7 100644 --- a/bbh.py +++ b/bbh.py @@ -67,9 +67,12 @@ def evaluate(model: EvalModel, data: BBHData, ntrain: int) -> dict: prompts.append(prompt) preds = model.run(prompts) - is_correct.extend([pred.strip().startswith(label) for pred, label in zip(preds, data_test.target)]) - if i == 0: - print(dict(prompt=prompt, label=label, pred=pred)) + labels = [s.target for s in data_test.samples] + + 
print(dict(prompt=prompt, label=labels[0], pred=preds[0])) + is_correct.extend( + [pred.strip().startswith(label) for pred, label in zip(preds, labels)] + ) return dict(score=sum(is_correct) / len(is_correct)) diff --git a/crass.py b/crass.py index 2301e20..03e1102 100644 --- a/crass.py +++ b/crass.py @@ -141,6 +141,8 @@ def evaluate(model: EvalModel, data_train: CrassData, data_test: CrassData) -> d progress = tqdm(data_test.samples) sample: CrassSample + + prompts, labels = [], [] for sample in progress: # get prompt and make sure it fits k = int(len(data_train.samples)) @@ -154,12 +156,16 @@ def evaluate(model: EvalModel, data_train: CrassData, data_test: CrassData) -> d prompt = train_prompt + prompt_end label = sample.get_answer_label() - pred = model.run(prompt).strip() + prompts.append(prompt) + labels.append(label) + + preds = model.run(prompts) + preds = [i.strip() for i in preds] + for label, pred in zip(labels, preds): is_correct.append(pred.startswith(label)) - score = sum(is_correct) / len(is_correct) - progress.set_postfix(score=score) - print(dict(prompt=prompt, label=label, pred=pred)) + score = sum(is_correct) / len(is_correct) + print(dict(prompt=prompts[0], label=labels[0], pred=preds[0])) return dict(score=score) diff --git a/drop.py b/drop.py index 827b22d..a26a8db 100644 --- a/drop.py +++ b/drop.py @@ -103,6 +103,9 @@ def evaluate(model: EvalModel, data: DropData, ntrain: int) -> dict: progress = tqdm(data_test.samples) sample: DropSample + + prompts = [] + labels = [] for sample in progress: # get prompt and make sure it fits k = int(ntrain) @@ -116,11 +119,16 @@ def evaluate(model: EvalModel, data: DropData, ntrain: int) -> dict: prompt = train_prompt + prompt_end label = sample.get_answers()[0] - pred = model.run(prompt).strip() + prompts.append(prompt) + labels.append(label) + + preds = model.run(prompts) + preds = [i.strip() for i in preds] + for label, pred in zip(labels, preds): is_correct.append(pred.startswith(label)) - score = sum(is_correct) / len(is_correct) - progress.set_postfix(score=score) - print(dict(prompt=prompt, label=label, pred=pred)) + + score = sum(is_correct) / len(is_correct) + print(dict(prompt=prompts[0], label=labels[0], pred=preds[0])) return dict(score=score) diff --git a/human_eval/main.py b/human_eval/main.py index 5ea5063..20a9f93 100644 --- a/human_eval/main.py +++ b/human_eval/main.py @@ -78,29 +78,39 @@ def evaluate(model: EvalModel, data_path: str, **kwargs) -> dict: dataset = read_problems(data_path) n_sample = kwargs.get("n_sample", 1) best_temperature = {1: 0.1, 10: 0.6, 100: 0.8} + temperature = best_temperature[n_sample] + samples = [] progress_bar = tqdm(total=len(dataset) * n_sample, desc="Generating samples") + + prompts = [] + task_ids = [] for task_id in dataset: for i in range(n_sample): prompt = dataset[task_id]["prompt"] prompt = gen_prompt(prompt, model) - temperature = best_temperature[n_sample] - if temperature > 0: - completion = model.run(prompt, temperature=temperature, do_sample=True) - else: - completion = model.run(prompt) - - completion = fix_indents(completion) - sample = dict(task_id=task_id, completion=filter_code(completion, model)) - if i == 0: - print("Prompt: ", "-" * 100) - print(prompt) - print("Completion: ", "-" * 100) - print(filter_code(completion, model)) - samples.append(sample) + prompts.append(prompt) + task_ids.append(task_id) progress_bar.update(1) progress_bar.close() + if temperature > 0: + completions = model.run(prompts, temperature=temperature, do_sample=True) + else: + completions = 
model.run(prompts) + + for i, (prompt, completion, task_id) in enumerate( + zip(prompts, completions, task_ids) + ): + completion = fix_indents(completion) + sample = dict(task_id=task_id, completion=filter_code(completion, model)) + if i == 0: + print("Prompt: ", "-" * 100) + print(prompt) + print("Completion: ", "-" * 100) + print(filter_code(completion, model)) + samples.append(sample) + model_name = model.model_path.replace("/", "_") pred_filename = f"humaneval_{model_name}_predictions.jsonl" write_jsonl(pred_filename, samples) diff --git a/modeling.py b/modeling.py index 93c15ea..0973101 100644 --- a/modeling.py +++ b/modeling.py @@ -276,6 +276,9 @@ def run(self, prompt: str, **kwargs) -> str: self.load() all_outputs = [] + do_sample = kwargs.pop("do_sample", self.do_sample) + max_output_length = kwargs.pop("max_output_length", self.max_output_length) + for i in range(0, len(prompt), self.batch_size): batch_prompt = prompt[i : i + self.batch_size] inputs = self.tokenizer(batch_prompt, return_tensors="pt", padding=True).to( @@ -286,9 +289,9 @@ def run(self, prompt: str, **kwargs) -> str: outputs = self.model.generate( **inputs, - max_new_tokens=self.max_output_length, + max_new_tokens=max_output_length, pad_token_id=self.tokenizer.eos_token_id, # Avoid pad token warning - do_sample=self.do_sample, + do_sample=do_sample, **kwargs, ) batch_size, length = inputs.input_ids.shape From 121f2ea590c60a189a7603aff5c5d3a323e2faf3 Mon Sep 17 00:00:00 2001 From: Jingwei Yi Date: Fri, 27 Oct 2023 08:08:32 +0000 Subject: [PATCH 7/7] add dtype in vllmMOdel --- modeling.py | 7 ++----- 1 file changed, 2 insertions(+), 5 deletions(-) diff --git a/modeling.py b/modeling.py index 0973101..8f4c703 100644 --- a/modeling.py +++ b/modeling.py @@ -138,18 +138,15 @@ class vllmModel(EvalModel): load_8bit: bool = False temperature: float = 0.0 tensor_parallel_size: int = 1 + dtype: str = "float16" def load(self): if self.model is None: - args = {} - if self.load_8bit: - args.update(device_map="auto", load_in_8bit=True) - else: - args.update(device_map="auto", torch_dtype=torch.float16) self.model = LLM( model=self.model_path, trust_remote_code=self.trust_remote_code, tensor_parallel_size=self.tensor_parallel_size, + dtype=self.dtype, ) if self.tokenizer is None: self.tokenizer = AutoTokenizer.from_pretrained(
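
Taken together, the series switches EvalModel.run from a one-prompt-in, one-string-out call to a batched one: each task script now collects all of its prompts first, calls run(prompts) once, and scores the returned list of completions. Below is a minimal usage sketch of that batched interface, assuming the select_model registry and the vllmModel fields added above; the checkpoint path and the tensor_parallel_size/dtype/max_output_length values are illustrative placeholders, not part of the patches.

    from modeling import select_model

    # Sketch only: "vllm" is the key this series registers in select_model;
    # the checkpoint and parallelism/dtype settings below are placeholders.
    model = select_model(
        "vllm",
        model_path="meta-llama/Llama-2-7b-hf",  # placeholder model path
        tensor_parallel_size=1,
        dtype="float16",
        max_output_length=128,
    )

    prompts = [
        "Q: What is 2 + 2?\nA:",
        "Q: What is the capital of France?\nA:",
    ]

    # run() now takes a list of prompts and returns a list of completions;
    # do_sample=False is mapped to temperature=0 (greedy) inside vllmModel.run.
    preds = model.run(prompts, do_sample=False)
    for prompt, pred in zip(prompts, preds):
        print(repr(prompt), "->", repr(pred.strip()))

From the caller's side the batched CausalModel behaves the same way: it splits the prompt list into batch_size chunks, pads each chunk on the left (hence padding_side="left" and pad_token = eos_token in its tokenizer setup), and slices outputs[:, length:] so that only newly generated tokens are decoded for every row.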