From 279ed957e702fca171f920cca81b15d3ef158580 Mon Sep 17 00:00:00 2001
From: Costa Huang
Date: Fri, 6 Sep 2024 15:32:38 +0000
Subject: [PATCH 1/4] Fix generation

---
 .../rejection_sampling/generation.py          | 24 +++++++++----------
 1 file changed, 11 insertions(+), 13 deletions(-)

diff --git a/open_instruct/rejection_sampling/generation.py b/open_instruct/rejection_sampling/generation.py
index 671e753bf..68087c4b8 100644
--- a/open_instruct/rejection_sampling/generation.py
+++ b/open_instruct/rejection_sampling/generation.py
@@ -70,10 +70,9 @@ class GenerationArgs:
 class DatasetArgs:
     dataset_name: str = None
     dataset_text_field: str = "prompt"
-    dataset_train_split: str = "train"
-    dataset_test_split: str = "validation"
+    split: str = "train"
     dataset_start_idx: int = 0
-    dataset_end_idx: Optional[int] = 100
+    dataset_end_idx: Optional[int] = None
     sanity_check: bool = False
     sanity_check_size: int = 100
 
@@ -100,10 +99,11 @@ def generate_with_vllm(model_name_or_path: str, revision: str, prompt_token_ids:
         revision=revision,
         tokenizer_revision=revision,
         tensor_parallel_size=gen_args.tensor_parallel_size,
+        max_model_len=gen_args.response_length,
     )
 
     # filter out prompts which are beyond the model's max token length
-    max_model_len = llm.llm_engine.scheduler_config.max_model_len
+    max_model_len = gen_args.response_length
     prompt_token_ids_len = len(prompt_token_ids)
     prompt_token_ids = [item for item in prompt_token_ids if len(item) < max_model_len]
     if len(prompt_token_ids) != prompt_token_ids_len:
@@ -146,14 +146,12 @@ def format_conversation(messages: list) -> str:
 
 
 def main(args: Args, dataset_args: DatasetArgs, gen_args: GenerationArgs):
-    ds = load_dataset(dataset_args.dataset_name)
+    ds = load_dataset(dataset_args.dataset_name, split=dataset_args.split)
     if dataset_args.sanity_check:
-        for key in ds:
-            ds[key] = ds[key].select(range(min(dataset_args.sanity_check_size, len(ds[key]))))
+        ds = ds.select(range(min(dataset_args.sanity_check_size, len(ds))))
     if dataset_args.dataset_end_idx is None:
-        dataset_args.dataset_end_idx = len(ds[dataset_args.dataset_train_split])
-    for key in ds:
-        ds[key] = ds[key].select(range(dataset_args.dataset_start_idx, dataset_args.dataset_end_idx))
+        dataset_args.dataset_end_idx = len(ds)
+    ds = ds.select(range(dataset_args.dataset_start_idx, dataset_args.dataset_end_idx))
 
     pprint([dataset_args, args, gen_args])
     if "gpt-3.5" in args.model_name_or_path or "gpt-4" in args.model_name_or_path:
@@ -161,7 +159,7 @@ def main(args: Args, dataset_args: DatasetArgs, gen_args: GenerationArgs):
             lambda x: {"prompt": format_conversation(x["messages"][:-1])},
             num_proc=NUM_CPUS_FOR_DATASET_MAP,
         )
-        messages = ds[dataset_args.dataset_train_split]["prompt"]
+        messages = ds["prompt"]
         responses = asyncio.run(generate_with_openai(args.model_name_or_path, messages, args, gen_args))
         outputs = [{"outputs": [{"text": response} for response in responses]}]
 
@@ -172,7 +170,7 @@ def main(args: Args, dataset_args: DatasetArgs, gen_args: GenerationArgs):
             lambda x: {"prompt_token_ids": tokenizer.apply_chat_template(x["messages"][:-1])},
             num_proc=NUM_CPUS_FOR_DATASET_MAP,
         )
-        prompt_token_ids = ds[dataset_args.dataset_train_split]["prompt_token_ids"]
+        prompt_token_ids = ds["prompt_token_ids"]
         outputs = generate_with_vllm(args.model_name_or_path, args.revision, prompt_token_ids, gen_args)
 
     # Assuming we generate n=3 completions per prompt; the outputs will look like:
@@ -185,7 +183,7 @@ def main(args: Args, dataset_args: DatasetArgs, gen_args: GenerationArgs):
     # ...
     table = defaultdict(list)
     num_prompt_with_identical_completions = 0
-    for output, messages in zip(outputs, ds[dataset_args.dataset_train_split]["messages"]):
+    for output, messages in zip(outputs, ds["messages"]):
         # if the model completions are exactly the same across all completions per prompt, we can skip this
         if len(set(tuple(item["text"]) for item in output["outputs"])) == 1:
             num_prompt_with_identical_completions += 1
             continue

From 6d56bf14d01c2de184f5a2cacf197a538f6e9ebb Mon Sep 17 00:00:00 2001
From: Costa Huang
Date: Fri, 6 Sep 2024 15:37:43 +0000
Subject: [PATCH 2/4] quick fix

---
 open_instruct/rejection_sampling/generation.py | 6 +++---
 1 file changed, 3 insertions(+), 3 deletions(-)

diff --git a/open_instruct/rejection_sampling/generation.py b/open_instruct/rejection_sampling/generation.py
index 68087c4b8..3f07e1233 100644
--- a/open_instruct/rejection_sampling/generation.py
+++ b/open_instruct/rejection_sampling/generation.py
@@ -94,18 +94,18 @@ async def generate_with_openai(model_name: str, data_list: list, args: Args, gen
 
 
 def generate_with_vllm(model_name_or_path: str, revision: str, prompt_token_ids: List[int], gen_args: GenerationArgs):
+    max_context_length = gen_args.response_length - 1
     llm = LLM(
         model=model_name_or_path,
         revision=revision,
         tokenizer_revision=revision,
         tensor_parallel_size=gen_args.tensor_parallel_size,
-        max_model_len=gen_args.response_length,
+        max_model_len=max_context_length,
     )
 
     # filter out prompts which are beyond the model's max token length
-    max_model_len = gen_args.response_length
     prompt_token_ids_len = len(prompt_token_ids)
-    prompt_token_ids = [item for item in prompt_token_ids if len(item) < max_model_len]
+    prompt_token_ids = [item for item in prompt_token_ids if len(item) <= max_context_length]
     if len(prompt_token_ids) != prompt_token_ids_len:
         print(f"Filtered out {prompt_token_ids_len - len(prompt_token_ids)} prompts which exceeds max token length")

From c600a3c407800e663b091dd0f07b6afd6747ab42 Mon Sep 17 00:00:00 2001
From: Costa Huang
Date: Fri, 6 Sep 2024 15:45:28 +0000
Subject: [PATCH 3/4] quick fix

---
 open_instruct/rejection_sampling/generation.py | 11 +++++++----
 1 file changed, 7 insertions(+), 4 deletions(-)

diff --git a/open_instruct/rejection_sampling/generation.py b/open_instruct/rejection_sampling/generation.py
index 3f07e1233..72c8db5d9 100644
--- a/open_instruct/rejection_sampling/generation.py
+++ b/open_instruct/rejection_sampling/generation.py
@@ -61,9 +61,13 @@ class Args:
 class GenerationArgs:
     num_completions: int = 3
     temperature: float = 0.8
+    max_model_length: int = 4096
     response_length: int = 2048
     top_p: float = 0.9
     tensor_parallel_size: int = 1
+
+    def __post_init__(self):
+        assert self.response_length <= self.max_model_length, "response_length should be less than or equal to max_model_length"
 
 
 @dataclass
@@ -94,20 +98,19 @@ async def generate_with_openai(model_name: str, data_list: list, args: Args, gen
 
 
 def generate_with_vllm(model_name_or_path: str, revision: str, prompt_token_ids: List[int], gen_args: GenerationArgs):
-    max_context_length = gen_args.response_length - 1
     llm = LLM(
         model=model_name_or_path,
         revision=revision,
         tokenizer_revision=revision,
         tensor_parallel_size=gen_args.tensor_parallel_size,
-        max_model_len=max_context_length,
+        max_model_len=gen_args.max_model_length,
    )
 
     # filter out prompts which are beyond the model's max token length
     prompt_token_ids_len = len(prompt_token_ids)
-    prompt_token_ids = [item for item in prompt_token_ids if len(item) <= max_context_length]
+    prompt_token_ids = [item for item in prompt_token_ids if len(item) < gen_args.max_model_length]
     if len(prompt_token_ids) != prompt_token_ids_len:
-        print(f"Filtered out {prompt_token_ids_len - len(prompt_token_ids)} prompts which exceeds max token length")
+        print(f"Filtered out {prompt_token_ids_len - len(prompt_token_ids)} prompts which exceeds max context length")
 
     outputs = llm.generate(
         prompt_token_ids=prompt_token_ids,

From cbb6792d97aea4607faf3da017394ade7bfa88b4 Mon Sep 17 00:00:00 2001
From: Costa Huang
Date: Wed, 18 Sep 2024 19:19:49 +0000
Subject: [PATCH 4/4] push

---
 open_instruct/rejection_sampling/generation.py | 2 +-
 1 file changed, 1 insertion(+), 1 deletion(-)

diff --git a/open_instruct/rejection_sampling/generation.py b/open_instruct/rejection_sampling/generation.py
index 72c8db5d9..6ec7d0633 100644
--- a/open_instruct/rejection_sampling/generation.py
+++ b/open_instruct/rejection_sampling/generation.py
@@ -188,7 +188,7 @@ def main(args: Args, dataset_args: DatasetArgs, gen_args: GenerationArgs):
     num_prompt_with_identical_completions = 0
     for output, messages in zip(outputs, ds["messages"]):
         # if the model completions are exactly the same across all completions per prompt, we can skip this
-        if len(set(tuple(item["text"]) for item in output["outputs"])) == 1:
+        if len(set(tuple(item["text"]) for item in output["outputs"])) and output["outputs"][0] == messages[-1]["content"]:
             num_prompt_with_identical_completions += 1
             continue
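
Below is a minimal, self-contained sketch (not part of the patches above) of the behaviour the series converges on in PATCH 3/4: GenerationArgs enforces response_length <= max_model_length in __post_init__, and prompts whose token count reaches max_model_length are dropped before generation. The filter_long_prompts helper and the toy token-id lists are illustrative stand-ins only; the real script does this inline in generate_with_vllm using vLLM's LLM class.

from dataclasses import dataclass
from typing import List


@dataclass
class GenerationArgs:
    # Mirrors the dataclass after PATCH 3/4; only the fields used below matter here.
    num_completions: int = 3
    temperature: float = 0.8
    max_model_length: int = 4096
    response_length: int = 2048
    top_p: float = 0.9
    tensor_parallel_size: int = 1

    def __post_init__(self):
        # Same invariant the patch asserts: responses must fit inside the context window.
        assert self.response_length <= self.max_model_length, (
            "response_length should be less than or equal to max_model_length"
        )


def filter_long_prompts(prompt_token_ids: List[List[int]], gen_args: GenerationArgs) -> List[List[int]]:
    # Illustrative helper: keep only prompts strictly shorter than max_model_length,
    # matching the list comprehension in generate_with_vllm after PATCH 3/4.
    before = len(prompt_token_ids)
    kept = [item for item in prompt_token_ids if len(item) < gen_args.max_model_length]
    if len(kept) != before:
        print(f"Filtered out {before - len(kept)} prompts which exceed max context length")
    return kept


if __name__ == "__main__":
    gen_args = GenerationArgs(max_model_length=8, response_length=4)
    prompts = [[1, 2, 3], [4] * 8, [5] * 20]  # made-up token ids; the last two hit the limit
    print(filter_long_prompts(prompts, gen_args))  # -> [[1, 2, 3]]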