diff --git a/open_instruct/dataset_transformation.py b/open_instruct/dataset_transformation.py
index da86f102c..5bd3a3dec 100644
--- a/open_instruct/dataset_transformation.py
+++ b/open_instruct/dataset_transformation.py
@@ -763,7 +763,6 @@ def get_tokenizer_tulu_v2_2(tc: "TokenizerConfig"):
 DEFAULT_SFT_MESSAGES_KEY = "messages"
 GROUND_TRUTHS_KEY = "ground_truth"
 VERIFIER_SOURCE_KEY = "dataset"
-RAW_PROMPT_KEY = "prompt"
 
 
 @dataclass
@@ -815,14 +814,8 @@ def tokenizer(self):
 ATTENTION_MASK_KEY = "attention_mask"
 LABELS_KEY = "labels"
 DATASET_ORIGIN_KEY = "dataset_source"  # just 'dataset' clashes with RLVR stuff (see VERIFIER_SOURCE_KEY)
-TOKENIZED_SFT_DATASET_KEYS = [INPUT_IDS_KEY, ATTENTION_MASK_KEY, LABELS_KEY, RAW_PROMPT_KEY]
-TOKENIZED_SFT_DATASET_KEYS_WITH_SOURCE = [
-    INPUT_IDS_KEY,
-    ATTENTION_MASK_KEY,
-    LABELS_KEY,
-    DATASET_ORIGIN_KEY,
-    RAW_PROMPT_KEY,
-]
+TOKENIZED_SFT_DATASET_KEYS = [INPUT_IDS_KEY, ATTENTION_MASK_KEY, LABELS_KEY]
+TOKENIZED_SFT_DATASET_KEYS_WITH_SOURCE = [INPUT_IDS_KEY, ATTENTION_MASK_KEY, LABELS_KEY, DATASET_ORIGIN_KEY]
 
 
 def remove_dataset_source_field(dataset: Dataset) -> Dataset:
@@ -1192,8 +1185,6 @@ def rlvr_tokenize_v1(
     row[LABELS_KEY] = labels
     row[GROUND_TRUTHS_KEY] = row[ground_truths_key]
     row[VERIFIER_SOURCE_KEY] = row[verifier_source_key]
-    # concatenate all the previous messages as <role>: <content>\n<role>: <content>\n ...
-    row[RAW_PROMPT_KEY] = "\n".join(f"{msg['role']}: {msg['content']}" for msg in prompt)
     return row
 
 
@@ -1221,10 +1212,6 @@ def rlvr_tokenize_v2(
     row[LABELS_KEY] = labels
     row[GROUND_TRUTHS_KEY] = row[ground_truths_key]
     row[VERIFIER_SOURCE_KEY] = row[verifier_source_key]
-    # concatenate all the previous messages as <role>: <content>\n<role>: <content>\n ...
-    # row[DEFAULT_SFT_MESSAGES_KEY] = prompt
-    # concatenate all the previous messages as <role>: <content>\n<role>: <content>\n ...
-    row[RAW_PROMPT_KEY] = "\n".join(f"{msg['role']}: {msg['content']}" for msg in prompt)
     # some basic transformations:
     # if ground truths is a string, make it a list
     if isinstance(row[ground_truths_key], str):
@@ -1686,6 +1673,7 @@ def get_cached_dataset_tulu_with_statistics(
             frac_or_num_samples = float(frac_or_num_samples)
         else:
             frac_or_num_samples = int(frac_or_num_samples)
+
         dataset_config = DatasetConfig(
             dataset_name=dataset_name,
             dataset_split=dataset_mixer_list_splits[i],
diff --git a/open_instruct/ground_truth_utils.py b/open_instruct/ground_truth_utils.py
index 2c3208878..59b0e5c4a 100644
--- a/open_instruct/ground_truth_utils.py
+++ b/open_instruct/ground_truth_utils.py
@@ -38,14 +38,6 @@
 
 logger = logger_utils.setup_logger(__name__)
 
-# remove excessive logging from liteLLM
-logging.getLogger("LiteLLM").setLevel(logging.WARNING)
-logging.getLogger("litellm").setLevel(logging.ERROR)
-logging.getLogger("litellm.cost_calculator").setLevel(logging.CRITICAL)
-logging.getLogger("litellm._client").setLevel(logging.CRITICAL)
-logging.getLogger("cost_calculator").setLevel(logging.WARNING)
-logging.getLogger("httpx").setLevel(logging.WARNING)
-
 
 @dataclass
 class VerifierConfig:
@@ -671,7 +663,8 @@ async def async_call(
         for attempt in range(max_retries):
             # judges the quality of a response
             try:
-                messages = build_messages(prompt)
+                system_prompt = "Do not generate text between the and tags."  # "You are a concise assistant who gives very short explanations before giving a quality score."
+                messages = build_messages(prompt, system_prompt)
                 # Faeze: check if the request would exceed context window
                 # Import the context window checker
 
diff --git a/open_instruct/grpo_fast.py b/open_instruct/grpo_fast.py
index b5075c0e0..0a08e2fb4 100644
--- a/open_instruct/grpo_fast.py
+++ b/open_instruct/grpo_fast.py
@@ -87,7 +87,6 @@
 from open_instruct.dataset_transformation import (
     GROUND_TRUTHS_KEY,
     INPUT_IDS_PROMPT_KEY,
-    RAW_PROMPT_KEY,
     VERIFIER_SOURCE_KEY,
     TokenizerConfig,
     get_cached_dataset_tulu,
@@ -120,6 +119,7 @@
     calibrate_checkpoint_state_dir,
     clean_last_n_checkpoints_deepspeed,
     download_latest_checkpoint_from_gs,
+    extract_user_query,
     get_beaker_whoami,
     get_eval_ds_config,
     get_optimizer_grouped_parameters,
@@ -487,7 +487,6 @@ def next_batch(dataset_indices: List[int], dataset: datasets.Dataset) -> Batch:
         queries=data_next[INPUT_IDS_PROMPT_KEY],
         ground_truths=data_next[GROUND_TRUTHS_KEY],
         datasets=data_next[VERIFIER_SOURCE_KEY],
-        raw_queries=data_next[RAW_PROMPT_KEY],
         indices=dataset_indices,
     )
 
@@ -1280,46 +1279,28 @@ def __init__(self):
         self._map = {}  # dataset_idx -> (query, ground_truth, dataset, count)
         self._lock = threading.Lock()
 
-    def insert(self, dataset_idx, query, ground_truth, dataset, raw_query):
+    def insert(self, dataset_idx, query, ground_truth, dataset):
         """Insert or increment count for a dataset index."""
         with self._lock:
             if dataset_idx in self._map:
                 # Already exists - just increment count
-                existing_query, existing_ground_truth, existing_dataset, existing_raw_query, count = self._map[
-                    dataset_idx
-                ]
-                self._map[dataset_idx] = (
-                    existing_query,
-                    existing_ground_truth,
-                    existing_dataset,
-                    existing_raw_query,
-                    count + 1,
-                )
+                existing_query, existing_ground_truth, existing_dataset, count = self._map[dataset_idx]
+                self._map[dataset_idx] = (existing_query, existing_ground_truth, existing_dataset, count + 1)
             else:
                 # New entry - count starts at 1
-                self._map[dataset_idx] = (query, ground_truth, dataset, raw_query, 1)
+                self._map[dataset_idx] = (query, ground_truth, dataset, 1)
 
-    def insert_many(self, dataset_indices, queries, ground_truths, datasets, raw_queries):
+    def insert_many(self, dataset_indices, queries, ground_truths, datasets):
         """Insert or increment count for multiple dataset indices at once."""
         with self._lock:
             for i, dataset_idx in enumerate(dataset_indices):
-                current_raw_query = raw_queries[i]
-
                 if dataset_idx in self._map:
                     # Already exists - just increment count
-                    existing_query, existing_ground_truth, existing_dataset, existing_raw_query, count = self._map[
-                        dataset_idx
-                    ]
-                    self._map[dataset_idx] = (
-                        existing_query,
-                        existing_ground_truth,
-                        existing_dataset,
-                        existing_raw_query,
-                        count + 1,
-                    )
+                    existing_query, existing_ground_truth, existing_dataset, count = self._map[dataset_idx]
+                    self._map[dataset_idx] = (existing_query, existing_ground_truth, existing_dataset, count + 1)
                 else:
                     # New entry - count starts at 1
-                    self._map[dataset_idx] = (queries[i], ground_truths[i], datasets[i], current_raw_query, 1)
+                    self._map[dataset_idx] = (queries[i], ground_truths[i], datasets[i], 1)
 
     def pop(self, dataset_idx):
         """Retrieve data and decrement count. Removes entry when count reaches 0."""
@@ -1327,16 +1308,16 @@ def pop(self, dataset_idx):
         with self._lock:
             if dataset_idx not in self._map:
                 raise RuntimeError(f"Dataset index {dataset_idx} not found in pending_queries_map")
-            query, ground_truth, dataset, raw_query, count = self._map[dataset_idx]
+            query, ground_truth, dataset, count = self._map[dataset_idx]
 
             if count > 1:
                 # More results expected - just decrement
-                self._map[dataset_idx] = (query, ground_truth, dataset, raw_query, count - 1)
+                self._map[dataset_idx] = (query, ground_truth, dataset, count - 1)
             else:
                 # Last result - remove entry
                 del self._map[dataset_idx]
 
-            return query, ground_truth, dataset, raw_query
+            return query, ground_truth, dataset
 
     def __len__(self):
         """Return the number of entries in the map."""
@@ -1389,7 +1370,6 @@ def accumulate_inference_batches(
     all_queries = []
     all_ground_truths = []
     all_datasets = []
-    all_raw_queries = []
     for i in tqdm(
         range(args.vllm_num_engines),
         total=args.vllm_num_engines,
@@ -1421,20 +1401,17 @@ def accumulate_inference_batches(
         batch_queries = []
         batch_ground_truths = []
        batch_datasets = []
-        batch_raw_queries = []
         for dataset_idx in dataset_indices:
-            query, ground_truth, dataset, raw_query = pending_queries_map.pop(dataset_idx)
+            query, ground_truth, dataset = pending_queries_map.pop(dataset_idx)
             batch_queries.append(query)
             batch_ground_truths.append(ground_truth)
             batch_datasets.append(dataset)
-            batch_raw_queries.append(raw_query)
 
         results.append(result)
 
         all_queries.extend(batch_queries)
         all_ground_truths.extend(batch_ground_truths)
         all_datasets.extend(batch_datasets)
-        all_raw_queries.extend(batch_raw_queries)
 
     # Combine all results into a single GenerationResult
     combined_responses = []
@@ -1496,7 +1473,6 @@ def accumulate_inference_batches(
         queries=all_queries,
         ground_truths=all_ground_truths,
         datasets=all_datasets,
-        raw_queries=all_raw_queries,
         indices=None,  # Not meaningful for combined results
     )
     return combined_result, batch
@@ -1533,7 +1509,6 @@ def data_preparation_thread(
             queries=repeat_each(batch.queries, args.num_samples_per_prompt_rollout),
             ground_truths=repeat_each(batch.ground_truths, args.num_samples_per_prompt_rollout),
             datasets=repeat_each(batch.datasets, args.num_samples_per_prompt_rollout),
-            raw_queries=repeat_each(batch.raw_queries, args.num_samples_per_prompt_rollout),
             indices=repeat_each(batch.indices, args.num_samples_per_prompt_rollout) if batch.indices else None,
         )
         good_outputs = [
@@ -1555,7 +1530,8 @@ def data_preparation_thread(
 
         with Timer("🔥 [Data Preparation Thread] Decoding responses", noop=True):
             decoded_responses = tokenizer.batch_decode(result.responses, skip_special_tokens=True)
-            decoded_queries = batch.raw_queries
+            decoded_queries = tokenizer.batch_decode(batch.queries, skip_special_tokens=True)
+            decoded_queries = [extract_user_query(query) for query in decoded_queries]
             stop_rate = sum(int(finish_reason == "stop") for finish_reason in result.finish_reasons) / len(
                 result.finish_reasons
             )
@@ -1871,7 +1847,7 @@ def setup_experiment_tracking(args: Args, tc: TokenizerConfig, model_config: Mod
     wandb_url = None
 
     if args.with_tracking:
-        wandb.init(
+        run = wandb.init(
             project=args.wandb_project_name,
             entity=args.wandb_entity,
             config=all_configs,
@@ -1879,6 +1855,13 @@ def setup_experiment_tracking(args: Args, tc: TokenizerConfig, model_config: Mod
             save_code=True,
             tags=[args.exp_name] + get_wandb_tags(),
         )
+        run.define_metric("*", step_metric="episode")
+        run.define_metric("episode")
+        run.define_metric("training_step", step_metric="episode")
run.define_metric("val/*", step_metric="episode") + run.define_metric("objective/*", step_metric="episode") + run.define_metric("time/*", step_metric="episode") + run.define_metric("eval/*", step_metric="episode") wandb_url = wandb.run.get_url() maybe_update_beaker_description(wandb_url=wandb_url) @@ -2071,7 +2054,7 @@ def split_and_insert_batch( # Store prompts in the map using thread-safe insert_many pending_queries_map.insert_many( - sub_batch.indices, sub_batch.queries, sub_batch.ground_truths, sub_batch.datasets, sub_batch.raw_queries + sub_batch.indices, sub_batch.queries, sub_batch.ground_truths, sub_batch.datasets ) # Use PromptRequest for Ray queue with batch-specific dataset_index list @@ -2283,7 +2266,6 @@ def one_training_step( total_time = time.perf_counter() - start_time metrics = { "episode": episode, - "global_step": episode, "training_step": training_step, "val/num_total_tokens": num_total_tokens, "epoch": episode / args.num_samples_per_prompt_rollout / len(train_dataset), diff --git a/open_instruct/judge_utils.py b/open_instruct/judge_utils.py index 1cc47c7a1..2e0140016 100644 --- a/open_instruct/judge_utils.py +++ b/open_instruct/judge_utils.py @@ -31,20 +31,19 @@ AI assistant to the user query displayed below. Notes: -- Your evaluation should consider factors such as the helpfulness, relevance, accuracy, creativity, appropriate level of detail, and how well the response satisfies the user's explicit constraints or accurately follows their instructions. -- If there is a system prompt, ensure the AI answer prioritizes following it. -- Begin your evaluation by providing a short explanation. -- Be as objective as possible. After providing your short explanation, please output a score on a scale of 1 to 10. -- Please adhere to the following format. +1- Your evaluation should consider factors such as the helpfulness, relevance, accuracy, depth, creativity, and level of detail of the response. +2- Begin your evaluation by providing a short explanation. +3- Be as objective as possible. After providing your explanation, please rate the response on a scale of 1 to 10. -[Conversation History] +[Query] {input} -[AI Answer] +[Response] {output} [Your judgement] -Respond in JSON format. {{"REASONING": "[...]", "SCORE": ""}}""" +Respond in JSON format. {{"REASONING": "[...]", "SCORE": ""}} +""" general_quality_rubric_template = """ @@ -77,18 +76,16 @@ general_quality_ref_template = """ ### Task Description Please act as an impartial judge and evaluate the quality of the answer provided by an -AI assistant to the conversation history leading up to the answer displayed below. -Judge whether the provided answer is good by comparing it to the reference answer. +AI assistant to the user query displayed below. Judge whether the provided answer is good by comparing it to the reference answer. Notes: -- Besides comparing to the reference answer, your evaluation should consider factors such as the helpfulness, relevance, accuracy, creativity, appropriate level of detail, and how well the response satisfies the user's explicit constraints or accurately follows their instructions. +- Besides comparing to the referennce answer, your evaluation should consider factors such as the helpfulness, relevance, accuracy, depth, creativity, and appropriate level of detail of the response. - Note that sometimes the reference answer is not the only answer. So any valid variation of the reference answer is also acceptable and can get a full score. 
-- If there is a system prompt, ensure the AI answer prioritizes following it.
 - Begin your evaluation by providing a short explanation.
 - Be as objective as possible. After providing your short explanation, please output a score on a scale of 1 to 10.
 - Please adhere to the following format.
 
-[Conversation History]
+[Query]
 {input}
 
 [AI Answer]
diff --git a/open_instruct/model_utils.py b/open_instruct/model_utils.py
index 93474d7fc..225f6e7d5 100644
--- a/open_instruct/model_utils.py
+++ b/open_instruct/model_utils.py
@@ -54,7 +54,6 @@ class Batch:
     queries: List[List[int]]
     ground_truths: List[List[int]]
     datasets: List[str]
-    raw_queries: Optional[List[str]]
     indices: Optional[List[int]]
 
     def __getitem__(self, key: Union[slice, int, List[int]]) -> "Batch":
@@ -65,7 +64,6 @@ def __getitem__(self, key: Union[slice, int, List[int]]) -> "Batch":
                 queries=self.queries[key],
                 ground_truths=self.ground_truths[key],
                 datasets=self.datasets[key],
-                raw_queries=self.raw_queries[key],
                 indices=self.indices[key] if self.indices else None,
             )
         elif isinstance(key, int):
@@ -74,7 +72,6 @@ def __getitem__(self, key: Union[slice, int, List[int]]) -> "Batch":
                 queries=[self.queries[key]],
                 ground_truths=[self.ground_truths[key]],
                 datasets=[self.datasets[key]],
-                raw_queries=[self.raw_queries[key]],
                 indices=[self.indices[key]] if self.indices else None,
             )
         else:
@@ -83,7 +80,6 @@ def __getitem__(self, key: Union[slice, int, List[int]]) -> "Batch":
                 queries=[self.queries[i] for i in key],
                 ground_truths=[self.ground_truths[i] for i in key],
                 datasets=[self.datasets[i] for i in key],
-                raw_queries=[self.raw_queries[i] for i in key],
                 indices=[self.indices[i] for i in key] if self.indices else None,
             )
 
diff --git a/open_instruct/test_grpo_fast.py b/open_instruct/test_grpo_fast.py
index c66d2c3e7..be9cbd417 100644
--- a/open_instruct/test_grpo_fast.py
+++ b/open_instruct/test_grpo_fast.py
@@ -124,8 +124,7 @@ def create_test_data(self, num_prompts, prefix="", start_idx=0):
         queries = [f"{prefix}query_{i}" for i in indices]
         ground_truths = [f"{prefix}truth_{i}" for i in indices]
         datasets = [f"{prefix}dataset_{i}" for i in indices]
-        raw_queries = [f"{prefix}rawquery_{i}" for i in indices]
-        return queries, ground_truths, datasets, raw_queries, indices
+        return queries, ground_truths, datasets, indices
 
     def create_mock_args(self, num_engines=4, num_samples=1):
         """Create mock args object."""
@@ -154,9 +153,7 @@ def create_mock_result(self, dataset_indices, training_step, num_samples_per_pro
             dataset_index=dataset_indices,
         )
 
-    def setup_and_split_batch(
-        self, queries, ground_truths, datasets, raw_queries, indices, num_engines, training_step=1
-    ):
+    def setup_and_split_batch(self, queries, ground_truths, datasets, indices, num_engines, training_step=1):
         """Setup queues and split batch - common pattern."""
         param_prompt_Q = ray_queue.Queue(maxsize=num_engines * 2)
         inference_results_Q = ray_queue.Queue(maxsize=num_engines * 2)
@@ -165,9 +162,7 @@ def setup_and_split_batch(self, queries, ground_truths, datasets, indices, num_engines, training_step=1):
         # Track queues for cleanup
         self._ray_queues.extend([param_prompt_Q, inference_results_Q])
 
-        batch = model_utils.Batch(
-            queries=queries, ground_truths=ground_truths, datasets=datasets, raw_queries=raw_queries, indices=indices
-        )
+        batch = model_utils.Batch(queries=queries, ground_truths=ground_truths, datasets=datasets, indices=indices)
 
         # Create a mock generation_config for testing
         from unittest.mock import MagicMock
@@ -264,13 +259,13 @@ def test_vllm_queue_system_single_prompt(self):
     def test_batch_splitting_and_engine_configurations(self, vllm_num_engines: int, num_unique_prompts_rollout: int):
         """Test batch splitting and accumulation with various engine configurations."""
         # Create test data
-        queries_next, ground_truths_next, datasets_next, raw_queries_next, dataset_indices = self.create_test_data(
+        queries_next, ground_truths_next, datasets_next, dataset_indices = self.create_test_data(
             num_unique_prompts_rollout
         )
 
         # Setup and split batch
         param_prompt_Q, inference_results_Q, pending_queries_map = self.setup_and_split_batch(
-            queries_next, ground_truths_next, datasets_next, raw_queries_next, dataset_indices, vllm_num_engines
+            queries_next, ground_truths_next, datasets_next, dataset_indices, vllm_num_engines
         )
 
         # Verify that we have individual prompts in the map (not batches)
@@ -306,7 +301,7 @@ def test_batch_splitting_and_engine_configurations(self, vllm_num_engines: int,
             batch_ground_truths = []
             batch_datasets = []
             for idx in dataset_indices:
-                q, gt, d, _raw_q = pending_queries_map.pop(idx)
+                q, gt, d = pending_queries_map.pop(idx)
                 batch_queries.append(q)
                 batch_ground_truths.append(gt)
                 batch_datasets.append(d)
@@ -354,13 +349,13 @@ def test_dataset_index_preservation_through_pipeline(self):
         num_unique_prompts_rollout = 32
 
         # Create test data
-        queries_next, ground_truths_next, datasets_next, raw_queries_next, dataset_indices = self.create_test_data(
+        queries_next, ground_truths_next, datasets_next, dataset_indices = self.create_test_data(
             num_unique_prompts_rollout
         )
 
         # Setup and split batch
         param_prompt_Q, inference_results_Q, pending_queries_map = self.setup_and_split_batch(
-            queries_next, ground_truths_next, datasets_next, raw_queries_next, dataset_indices, vllm_num_engines
+            queries_next, ground_truths_next, datasets_next, dataset_indices, vllm_num_engines
         )
 
         # Simulate vLLM processing
@@ -381,7 +376,7 @@ def test_dataset_index_preservation_through_pipeline(self):
             dataset_indices = result.dataset_index
 
             for idx in dataset_indices:
-                q, gt, d, _raw_q = pending_queries_map.pop(idx)
+                q, gt, d = pending_queries_map.pop(idx)
                 combined_queries.append(q)
                 combined_ground_truths.append(gt)
                 combined_datasets.append(d)
@@ -398,13 +393,13 @@ def test_multiple_samples_per_prompt(self, vllm_num_engines: int, num_samples_pe
         num_unique_prompts_rollout = 16
 
         # Create test data
-        queries_next, ground_truths_next, datasets_next, raw_queries_next, dataset_indices = self.create_test_data(
+        queries_next, ground_truths_next, datasets_next, dataset_indices = self.create_test_data(
             num_unique_prompts_rollout
         )
 
         # Setup and split batch
         param_prompt_Q, inference_results_Q, pending_queries_map = self.setup_and_split_batch(
-            queries_next, ground_truths_next, datasets_next, raw_queries_next, dataset_indices, vllm_num_engines
+            queries_next, ground_truths_next, datasets_next, dataset_indices, vllm_num_engines
         )
 
         # Simulate vLLM processing with multiple samples
@@ -429,7 +424,7 @@ def test_multiple_samples_per_prompt(self, vllm_num_engines: int, num_samples_pe
             batch_ground_truths = []
             batch_datasets = []
             for idx in dataset_indices:
-                q, gt, d, _raw_q = pending_queries_map.pop(idx)
+                q, gt, d = pending_queries_map.pop(idx)
                 batch_queries.append(q)
                 batch_ground_truths.append(gt)
                 batch_datasets.append(d)
@@ -513,11 +508,11 @@ def test_out_of_order_processing(self):
         num_samples_per_prompt = 4
 
         # Create test data
-        queries, ground_truths, datasets, raw_queries, indices = self.create_test_data(num_prompts)
+        queries, ground_truths, datasets, indices = self.create_test_data(num_prompts)
 
         # Setup and split batch
         param_prompt_Q, inference_results_Q, pending_queries_map = self.setup_and_split_batch(
-            queries, ground_truths, datasets, raw_queries, indices, num_engines
+            queries, ground_truths, datasets, indices, num_engines
         )
 
         # Get all requests and process in reverse order
@@ -566,11 +561,7 @@ def add_and_remove_entries(thread_id):
             # Add entries
             for i in range(start_idx, start_idx + entries_per_thread):
                 pending_queries_map.insert(
-                    i,
-                    f"query_{thread_id}_{i}",
-                    f"truth_{thread_id}_{i}",
-                    f"dataset_{thread_id}_{i}",
-                    f"query_{thread_id}_{i}",
+                    i, f"query_{thread_id}_{i}", f"truth_{thread_id}_{i}", f"dataset_{thread_id}_{i}"
                 )
                 time.sleep(0.0001)
 
@@ -611,7 +602,7 @@ def test_accumulate_waits_for_all_engines(self):
 
         # Add entries to map
         for i in range(num_prompts):
-            pending_queries_map.insert(i, f"q_{i}", f"t_{i}", f"d_{i}", f"q_{i}")
+            pending_queries_map.insert(i, f"q_{i}", f"t_{i}", f"d_{i}")
 
         # Add results from only 3 engines (missing one)
         for engine_id in range(3):
@@ -665,16 +656,14 @@ def test_more_engines_than_queries(self):
         num_engines = 8
         num_queries = 4
 
-        queries, ground_truths, datasets, raw_queries, indices = self.create_test_data(num_queries)
+        queries, ground_truths, datasets, indices = self.create_test_data(num_queries)
 
         param_prompt_Q = ray_queue.Queue(maxsize=num_engines * 2)
         pending_queries_map = grpo_fast.PendingQueriesMap()
         # Track queue for cleanup
         self._ray_queues.append(param_prompt_Q)
 
-        batch = model_utils.Batch(
-            queries=queries, ground_truths=ground_truths, datasets=datasets, raw_queries=raw_queries, indices=indices
-        )
+        batch = model_utils.Batch(queries=queries, ground_truths=ground_truths, datasets=datasets, indices=indices)
 
         # Create a mock generation_config
         from unittest.mock import MagicMock
@@ -717,16 +706,14 @@ def test_uneven_distribution_no_empty_batches(self):
         num_engines = 3
         num_queries = 7  # 7/3 = ceil(2.33) = 3, so distribution should be [3, 3, 1]
 
-        queries, ground_truths, datasets, raw_queries, indices = self.create_test_data(num_queries)
+        queries, ground_truths, datasets, indices = self.create_test_data(num_queries)
 
         param_prompt_Q = ray_queue.Queue(maxsize=num_engines * 2)
         pending_queries_map = grpo_fast.PendingQueriesMap()
         # Track queue for cleanup
         self._ray_queues.append(param_prompt_Q)
 
-        batch = model_utils.Batch(
-            queries=queries, ground_truths=ground_truths, datasets=datasets, raw_queries=raw_queries, indices=indices
-        )
+        batch = model_utils.Batch(queries=queries, ground_truths=ground_truths, datasets=datasets, indices=indices)
 
         # Create a mock generation_config
         from unittest.mock import MagicMock
@@ -773,7 +760,7 @@ def test_streaming_accumulation_basic(self):
         num_prompts = 8
 
         # Create test data
-        queries, ground_truths, datasets, raw_queries, indices = self.create_test_data(num_prompts)
+        queries, ground_truths, datasets, indices = self.create_test_data(num_prompts)
 
         # Create queues and maps
         inference_results_Q = ray_queue.Queue(maxsize=num_engines * 2)
@@ -784,7 +771,7 @@ def test_streaming_accumulation_basic(self):
 
         # Insert data into pending_queries_map
         for i in range(num_prompts):
-            pending_queries_map.insert(i, queries[i], ground_truths[i], datasets[i], raw_queries[i])
+            pending_queries_map.insert(i, queries[i], ground_truths[i], datasets[i])
 
         # Create mock results with batch indices
         batch_size = num_prompts // num_engines
@@ -811,7 +798,7 @@ def test_streaming_accumulation_basic(self):
             batch_ground_truths = []
             batch_datasets = []
             for idx in dataset_indices:
-                q, gt, d, _raw_q = pending_queries_map.pop(idx)
+                q, gt, d = pending_queries_map.pop(idx)
                 batch_queries.append(q)
                 batch_ground_truths.append(gt)
                 batch_datasets.append(d)
@@ -838,7 +825,7 @@ def test_streaming_with_multiple_samples(self):
         num_samples = 3
 
         # Create test data
-        queries, ground_truths, datasets, raw_queries, indices = self.create_test_data(num_prompts)
+        queries, ground_truths, datasets, indices = self.create_test_data(num_prompts)
 
         # Create queues and maps
         inference_results_Q = ray_queue.Queue(maxsize=num_engines * 2)
@@ -850,7 +837,7 @@ def test_streaming_with_multiple_samples(self):
         # Insert data with reference counting for multiple samples
         for i in range(num_prompts):
             for _ in range(num_samples):
-                pending_queries_map.insert(i, queries[i], ground_truths[i], datasets[i], raw_queries[i])
+                pending_queries_map.insert(i, queries[i], ground_truths[i], datasets[i])
 
         # Create results with multiple samples per prompt
         batch_size = num_prompts // num_engines
diff --git a/open_instruct/utils.py b/open_instruct/utils.py
index d037dd0f4..2d57e6e2c 100644
--- a/open_instruct/utils.py
+++ b/open_instruct/utils.py
@@ -1495,6 +1495,18 @@ def backload_to_gpu(self, model, non_blocking=True):
 
 
 def extract_user_query(conversation: str, chat_template_name: str = None) -> str:
+    # only works for tulu_thinker_r1_style
+    # pattern = r"\n\n<\|user\|>\n(.*?)\n<\|assistant\|>\n"
+
+    # works for tulu_thinker_r1_style and tulu_thinker
+    # pattern = r"<\|user\|>\n(.*?)\n<\|assistant\|>\n"
+
+    # match = re.search(pattern, conversation, re.DOTALL)
+    # # Return the captured group if found, else return None
+    # return match.group(1).strip() if match else None
+
+    # works for olmo too:
+    # TODO: implement a better logic to get queries before creating the chat template
     pattern = re.compile(
         r"(?:"
         r"<\|user\|\>\n(?P<simple>.*?)\n<\|assistant\|\>\n"  # template 0 (your original)
         r")",
         re.DOTALL,
     )
+    # Get the last user query matched (most recent user turn before the assistant turn)
     matches = list(pattern.finditer(conversation))
     if matches:
         m = matches[-1]
         user_query = (m.group("simple") or m.group("im")).strip()
     else:
-        user_query = conversation
+        user_query = None
 
     return user_query