18 changes: 3 additions & 15 deletions open_instruct/dataset_transformation.py
@@ -763,7 +763,6 @@ def get_tokenizer_tulu_v2_2(tc: "TokenizerConfig"):
DEFAULT_SFT_MESSAGES_KEY = "messages"
GROUND_TRUTHS_KEY = "ground_truth"
VERIFIER_SOURCE_KEY = "dataset"
RAW_PROMPT_KEY = "prompt"


@dataclass
@@ -815,14 +814,8 @@ def tokenizer(self):
ATTENTION_MASK_KEY = "attention_mask"
LABELS_KEY = "labels"
DATASET_ORIGIN_KEY = "dataset_source" # just 'dataset' clashes with RLVR stuff (see VERIFIER_SOURCE_KEY)
TOKENIZED_SFT_DATASET_KEYS = [INPUT_IDS_KEY, ATTENTION_MASK_KEY, LABELS_KEY, RAW_PROMPT_KEY]
TOKENIZED_SFT_DATASET_KEYS_WITH_SOURCE = [
INPUT_IDS_KEY,
ATTENTION_MASK_KEY,
LABELS_KEY,
DATASET_ORIGIN_KEY,
RAW_PROMPT_KEY,
]
TOKENIZED_SFT_DATASET_KEYS = [INPUT_IDS_KEY, ATTENTION_MASK_KEY, LABELS_KEY]
TOKENIZED_SFT_DATASET_KEYS_WITH_SOURCE = [INPUT_IDS_KEY, ATTENTION_MASK_KEY, LABELS_KEY, DATASET_ORIGIN_KEY]


def remove_dataset_source_field(dataset: Dataset) -> Dataset:
@@ -1192,8 +1185,6 @@ def rlvr_tokenize_v1(
row[LABELS_KEY] = labels
row[GROUND_TRUTHS_KEY] = row[ground_truths_key]
row[VERIFIER_SOURCE_KEY] = row[verifier_source_key]
# concatenate all the previous messages as <role>: <content>\n <role>: <content>\n ...
row[RAW_PROMPT_KEY] = "\n".join(f"{msg['role']}: {msg['content']}" for msg in prompt)
return row


@@ -1221,10 +1212,6 @@ def rlvr_tokenize_v2(
row[LABELS_KEY] = labels
row[GROUND_TRUTHS_KEY] = row[ground_truths_key]
row[VERIFIER_SOURCE_KEY] = row[verifier_source_key]
# concatenate all the previous messages as <role>: <content>\n <role>: <content>\n ...
# row[DEFAULT_SFT_MESSAGES_KEY] = prompt
# concatenate all the previous messages as <role>: <content>\n <role>: <content>\n ...
row[RAW_PROMPT_KEY] = "\n".join(f"{msg['role']}: {msg['content']}" for msg in prompt)
# some basic transformations:
# if ground truths is a string, make it a list
if isinstance(row[ground_truths_key], str):
@@ -1686,6 +1673,7 @@ def get_cached_dataset_tulu_with_statistics(
frac_or_num_samples = float(frac_or_num_samples)
else:
frac_or_num_samples = int(frac_or_num_samples)

dataset_config = DatasetConfig(
dataset_name=dataset_name,
dataset_split=dataset_mixer_list_splits[i],
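Note on this file's changes: `RAW_PROMPT_KEY` is removed from the tokenized dataset keys, and the RLVR tokenizers no longer materialize a plain-text prompt. For reference, the removed field carried the prior messages joined as `<role>: <content>` lines. A minimal, self-contained sketch of that (now removed) formatting, with example messages that are illustrative rather than taken from any repo dataset:

```python
# Sketch of the string the removed RAW_PROMPT_KEY used to carry.
prompt = [
    {"role": "system", "content": "You are a helpful assistant."},
    {"role": "user", "content": "What is 2 + 2?"},
]
raw_prompt = "\n".join(f"{msg['role']}: {msg['content']}" for msg in prompt)
print(raw_prompt)
# system: You are a helpful assistant.
# user: What is 2 + 2?
```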
11 changes: 2 additions & 9 deletions open_instruct/ground_truth_utils.py
@@ -38,14 +38,6 @@

logger = logger_utils.setup_logger(__name__)

# remove excessive logging from liteLLM
logging.getLogger("LiteLLM").setLevel(logging.WARNING)
logging.getLogger("litellm").setLevel(logging.ERROR)
logging.getLogger("litellm.cost_calculator").setLevel(logging.CRITICAL)
logging.getLogger("litellm._client").setLevel(logging.CRITICAL)
logging.getLogger("cost_calculator").setLevel(logging.WARNING)
logging.getLogger("httpx").setLevel(logging.WARNING)


@dataclass
class VerifierConfig:
@@ -671,7 +663,8 @@ async def async_call(
for attempt in range(max_retries):
# judges the quality of a response
try:
messages = build_messages(prompt)
system_prompt = "Do not generate text between the <think> and </think> tags." # "You are a concise assistant who gives very short explanations before giving a quality score."
messages = build_messages(prompt, system_prompt)

# Faeze: check if the request would exceed context window
# Import the context window checker
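Note on the judge change above: `build_messages` now receives an explicit system prompt. Its implementation is not part of this diff; the sketch below is a hypothetical stand-in (not the repo's code) that assumes the usual OpenAI-style chat layout, purely to illustrate what the judge request likely looks like.

```python
from typing import Dict, List, Optional

def build_messages_sketch(prompt: str, system_prompt: Optional[str] = None) -> List[Dict[str, str]]:
    """Hypothetical stand-in for build_messages: prepend an optional
    system message ahead of the user prompt."""
    messages: List[Dict[str, str]] = []
    if system_prompt:
        messages.append({"role": "system", "content": system_prompt})
    messages.append({"role": "user", "content": prompt})
    return messages

# With the system prompt set above, the judge request would resemble:
# [{"role": "system", "content": "Do not generate text between the <think> and </think> tags."},
#  {"role": "user", "content": "<rendered judge template>"}]
```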
66 changes: 24 additions & 42 deletions open_instruct/grpo_fast.py
@@ -87,7 +87,6 @@
from open_instruct.dataset_transformation import (
GROUND_TRUTHS_KEY,
INPUT_IDS_PROMPT_KEY,
RAW_PROMPT_KEY,
VERIFIER_SOURCE_KEY,
TokenizerConfig,
get_cached_dataset_tulu,
@@ -120,6 +119,7 @@
calibrate_checkpoint_state_dir,
clean_last_n_checkpoints_deepspeed,
download_latest_checkpoint_from_gs,
extract_user_query,
get_beaker_whoami,
get_eval_ds_config,
get_optimizer_grouped_parameters,
@@ -487,7 +487,6 @@ def next_batch(dataset_indices: List[int], dataset: datasets.Dataset) -> Batch:
queries=data_next[INPUT_IDS_PROMPT_KEY],
ground_truths=data_next[GROUND_TRUTHS_KEY],
datasets=data_next[VERIFIER_SOURCE_KEY],
raw_queries=data_next[RAW_PROMPT_KEY],
indices=dataset_indices,
)

@@ -1280,63 +1279,45 @@ def __init__(self):
self._map = {} # dataset_idx -> (query, ground_truth, dataset, count)
self._lock = threading.Lock()

def insert(self, dataset_idx, query, ground_truth, dataset, raw_query):
def insert(self, dataset_idx, query, ground_truth, dataset):
"""Insert or increment count for a dataset index."""
with self._lock:
if dataset_idx in self._map:
# Already exists - just increment count
existing_query, existing_ground_truth, existing_dataset, existing_raw_query, count = self._map[
dataset_idx
]
self._map[dataset_idx] = (
existing_query,
existing_ground_truth,
existing_dataset,
existing_raw_query,
count + 1,
)
existing_query, existing_ground_truth, existing_dataset, count = self._map[dataset_idx]
self._map[dataset_idx] = (existing_query, existing_ground_truth, existing_dataset, count + 1)
else:
# New entry - count starts at 1
self._map[dataset_idx] = (query, ground_truth, dataset, raw_query, 1)
self._map[dataset_idx] = (query, ground_truth, dataset, 1)

def insert_many(self, dataset_indices, queries, ground_truths, datasets, raw_queries):
def insert_many(self, dataset_indices, queries, ground_truths, datasets):
"""Insert or increment count for multiple dataset indices at once."""
with self._lock:
for i, dataset_idx in enumerate(dataset_indices):
current_raw_query = raw_queries[i]

if dataset_idx in self._map:
# Already exists - just increment count
existing_query, existing_ground_truth, existing_dataset, existing_raw_query, count = self._map[
dataset_idx
]
self._map[dataset_idx] = (
existing_query,
existing_ground_truth,
existing_dataset,
existing_raw_query,
count + 1,
)
existing_query, existing_ground_truth, existing_dataset, count = self._map[dataset_idx]
self._map[dataset_idx] = (existing_query, existing_ground_truth, existing_dataset, count + 1)
else:
# New entry - count starts at 1
self._map[dataset_idx] = (queries[i], ground_truths[i], datasets[i], current_raw_query, 1)
self._map[dataset_idx] = (queries[i], ground_truths[i], datasets[i], 1)

def pop(self, dataset_idx):
"""Retrieve data and decrement count. Removes entry when count reaches 0."""
with self._lock:
if dataset_idx not in self._map:
raise RuntimeError(f"Dataset index {dataset_idx} not found in pending_queries_map")

query, ground_truth, dataset, raw_query, count = self._map[dataset_idx]
query, ground_truth, dataset, count = self._map[dataset_idx]

if count > 1:
# More results expected - just decrement
self._map[dataset_idx] = (query, ground_truth, dataset, raw_query, count - 1)
self._map[dataset_idx] = (query, ground_truth, dataset, count - 1)
else:
# Last result - remove entry
del self._map[dataset_idx]

return query, ground_truth, dataset, raw_query
return query, ground_truth, dataset

def __len__(self):
"""Return the number of entries in the map."""
@@ -1389,7 +1370,6 @@ def accumulate_inference_batches(
all_queries = []
all_ground_truths = []
all_datasets = []
all_raw_queries = []
for i in tqdm(
range(args.vllm_num_engines),
total=args.vllm_num_engines,
@@ -1421,20 +1401,17 @@ def accumulate_inference_batches(
batch_queries = []
batch_ground_truths = []
batch_datasets = []
batch_raw_queries = []

for dataset_idx in dataset_indices:
query, ground_truth, dataset, raw_query = pending_queries_map.pop(dataset_idx)
query, ground_truth, dataset = pending_queries_map.pop(dataset_idx)
batch_queries.append(query)
batch_ground_truths.append(ground_truth)
batch_datasets.append(dataset)
batch_raw_queries.append(raw_query)

results.append(result)
all_queries.extend(batch_queries)
all_ground_truths.extend(batch_ground_truths)
all_datasets.extend(batch_datasets)
all_raw_queries.extend(batch_raw_queries)

# Combine all results into a single GenerationResult
combined_responses = []
@@ -1496,7 +1473,6 @@ def accumulate_inference_batches(
queries=all_queries,
ground_truths=all_ground_truths,
datasets=all_datasets,
raw_queries=all_raw_queries,
indices=None, # Not meaningful for combined results
)
return combined_result, batch
@@ -1533,7 +1509,6 @@ def data_preparation_thread(
queries=repeat_each(batch.queries, args.num_samples_per_prompt_rollout),
ground_truths=repeat_each(batch.ground_truths, args.num_samples_per_prompt_rollout),
datasets=repeat_each(batch.datasets, args.num_samples_per_prompt_rollout),
raw_queries=repeat_each(batch.raw_queries, args.num_samples_per_prompt_rollout),
indices=repeat_each(batch.indices, args.num_samples_per_prompt_rollout) if batch.indices else None,
)
good_outputs = [
@@ -1555,7 +1530,8 @@ def data_preparation_thread(

with Timer("🔥 [Data Preparation Thread] Decoding responses", noop=True):
decoded_responses = tokenizer.batch_decode(result.responses, skip_special_tokens=True)
decoded_queries = batch.raw_queries
decoded_queries = tokenizer.batch_decode(batch.queries, skip_special_tokens=True)
decoded_queries = [extract_user_query(query) for query in decoded_queries]
stop_rate = sum(int(finish_reason == "stop") for finish_reason in result.finish_reasons) / len(
result.finish_reasons
)
@@ -1871,14 +1847,21 @@ def setup_experiment_tracking(args: Args, tc: TokenizerConfig, model_config: Mod

wandb_url = None
if args.with_tracking:
wandb.init(
run = wandb.init(
project=args.wandb_project_name,
entity=args.wandb_entity,
config=all_configs,
name=args.run_name,
save_code=True,
tags=[args.exp_name] + get_wandb_tags(),
)
run.define_metric("*", step_metric="episode")
run.define_metric("episode")
run.define_metric("training_step", step_metric="episode")
run.define_metric("val/*", step_metric="episode")
run.define_metric("objective/*", step_metric="episode")
run.define_metric("time/*", step_metric="episode")
run.define_metric("eval/*", step_metric="episode")
wandb_url = wandb.run.get_url()
maybe_update_beaker_description(wandb_url=wandb_url)
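Note on the tracking change above: `define_metric("*", step_metric="episode")` tells wandb to chart every logged metric against the custom `episode` counter rather than its internal step. A minimal usage sketch, with values that are illustrative only (the metric names mirror those in the training-step metrics dict later in this file):

```python
# Any log call that includes an "episode" value is now plotted against it.
run.log({"episode": 1024, "training_step": 4, "val/num_total_tokens": 2.1e6})
run.log({"episode": 2048, "training_step": 8, "val/num_total_tokens": 4.3e6})
```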

@@ -2071,7 +2054,7 @@ def split_and_insert_batch(

# Store prompts in the map using thread-safe insert_many
pending_queries_map.insert_many(
sub_batch.indices, sub_batch.queries, sub_batch.ground_truths, sub_batch.datasets, sub_batch.raw_queries
sub_batch.indices, sub_batch.queries, sub_batch.ground_truths, sub_batch.datasets
)

# Use PromptRequest for Ray queue with batch-specific dataset_index list
@@ -2283,7 +2266,6 @@ def one_training_step(
total_time = time.perf_counter() - start_time
metrics = {
"episode": episode,
"global_step": episode,
"training_step": training_step,
"val/num_total_tokens": num_total_tokens,
"epoch": episode / args.num_samples_per_prompt_rollout / len(train_dataset),
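Summary of the `PendingQueriesMap` shape after this file's changes: entries are reference-counted `(query, ground_truth, dataset, count)` tuples keyed by dataset index, with `raw_query` no longer stored. The condensed, self-contained sketch below mirrors the insert/pop logic shown in the hunks above.

```python
import threading

class RefCountedPendingMap:
    """Condensed illustration of the post-PR PendingQueriesMap behaviour."""

    def __init__(self):
        self._map = {}  # dataset_idx -> (query, ground_truth, dataset, count)
        self._lock = threading.Lock()

    def insert(self, dataset_idx, query, ground_truth, dataset):
        with self._lock:
            if dataset_idx in self._map:
                q, gt, ds, count = self._map[dataset_idx]
                self._map[dataset_idx] = (q, gt, ds, count + 1)
            else:
                self._map[dataset_idx] = (query, ground_truth, dataset, 1)

    def pop(self, dataset_idx):
        with self._lock:
            q, gt, ds, count = self._map[dataset_idx]
            if count > 1:
                self._map[dataset_idx] = (q, gt, ds, count - 1)
            else:
                del self._map[dataset_idx]
            return q, gt, ds
```

Each prompt is inserted once per expected rollout and popped once per returned result, so an entry disappears exactly when the last sample for that prompt arrives.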
23 changes: 10 additions & 13 deletions open_instruct/judge_utils.py
@@ -31,20 +31,19 @@
AI assistant to the user query displayed below.

Notes:
- Your evaluation should consider factors such as the helpfulness, relevance, accuracy, creativity, appropriate level of detail, and how well the response satisfies the user's explicit constraints or accurately follows their instructions.
- If there is a system prompt, ensure the AI answer prioritizes following it.
- Begin your evaluation by providing a short explanation.
- Be as objective as possible. After providing your short explanation, please output a score on a scale of 1 to 10.
- Please adhere to the following format.
1- Your evaluation should consider factors such as the helpfulness, relevance, accuracy, depth, creativity, and level of detail of the response.
2- Begin your evaluation by providing a short explanation.
3- Be as objective as possible. After providing your explanation, please rate the response on a scale of 1 to 10.

[Conversation History]
[Query]
{input}

[AI Answer]
[Response]
{output}

[Your judgement]
Respond in JSON format. {{"REASONING": "[...]", "SCORE": "<your-score>"}}"""
Respond in JSON format. {{"REASONING": "[...]", "SCORE": "<your-score>"}}
"""


general_quality_rubric_template = """
@@ -77,18 +76,16 @@
general_quality_ref_template = """
### Task Description
Please act as an impartial judge and evaluate the quality of the answer provided by an
AI assistant to the conversation history leading up to the answer displayed below.
Judge whether the provided answer is good by comparing it to the reference answer.
AI assistant to the user query displayed below. Judge whether the provided answer is good by comparing it to the reference answer.

Notes:
- Besides comparing to the reference answer, your evaluation should consider factors such as the helpfulness, relevance, accuracy, creativity, appropriate level of detail, and how well the response satisfies the user's explicit constraints or accurately follows their instructions.
- Besides comparing to the reference answer, your evaluation should consider factors such as the helpfulness, relevance, accuracy, depth, creativity, and appropriate level of detail of the response.
- Note that sometimes the reference answer is not the only answer. So any valid variation of the reference answer is also acceptable and can get a full score.
- If there is a system prompt, ensure the AI answer prioritizes following it.
- Begin your evaluation by providing a short explanation.
- Be as objective as possible. After providing your short explanation, please output a score on a scale of 1 to 10.
- Please adhere to the following format.

[Conversation History]
[Query]
{input}

[AI Answer]
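The templates above ask the judge to answer in JSON with `REASONING` and `SCORE` fields. How that output is parsed is outside this diff; the sketch below is an illustrative parser (not the repo's own) for the requested format.

```python
import json
import re

def parse_judge_score(raw: str) -> float:
    """Illustrative parser for {"REASONING": "...", "SCORE": "<your-score>"}."""
    match = re.search(r"\{.*\}", raw, flags=re.DOTALL)  # tolerate prose around the JSON
    if match is None:
        raise ValueError("no JSON object found in judge output")
    payload = json.loads(match.group(0))
    return float(payload["SCORE"])

print(parse_judge_score('{"REASONING": "Accurate and concise.", "SCORE": "9"}'))  # 9.0
```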
4 changes: 0 additions & 4 deletions open_instruct/model_utils.py
@@ -54,7 +54,6 @@ class Batch:
queries: List[List[int]]
ground_truths: List[List[int]]
datasets: List[str]
raw_queries: Optional[List[str]]
indices: Optional[List[int]]

def __getitem__(self, key: Union[slice, int, List[int]]) -> "Batch":
@@ -65,7 +64,6 @@ def __getitem__(self, key: Union[slice, int, List[int]]) -> "Batch":
queries=self.queries[key],
ground_truths=self.ground_truths[key],
datasets=self.datasets[key],
raw_queries=self.raw_queries[key],
indices=self.indices[key] if self.indices else None,
)
elif isinstance(key, int):
@@ -74,7 +72,6 @@ def __getitem__(self, key: Union[slice, int, List[int]]) -> "Batch":
queries=[self.queries[key]],
ground_truths=[self.ground_truths[key]],
datasets=[self.datasets[key]],
raw_queries=[self.raw_queries[key]],
indices=[self.indices[key]] if self.indices else None,
)
else:
@@ -83,7 +80,6 @@ def __getitem__(self, key: Union[slice, int, List[int]]) -> "Batch":
queries=[self.queries[i] for i in key],
ground_truths=[self.ground_truths[i] for i in key],
datasets=[self.datasets[i] for i in key],
raw_queries=[self.raw_queries[i] for i in key],
indices=[self.indices[i] for i in key] if self.indices else None,
)

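With `raw_queries` gone, `Batch.__getitem__` still supports slice, int, and list indexing over the remaining fields. A small usage sketch, assuming the fields shown in this hunk are the complete set (any fields hidden above the fold would also need to be supplied); the token IDs and dataset names are made up:

```python
from open_instruct.model_utils import Batch

batch = Batch(
    queries=[[1, 2], [3, 4], [5, 6]],
    ground_truths=[[7], [8], [9]],
    datasets=["gsm8k", "math", "gsm8k"],
    indices=[0, 1, 2],
)
print(batch[[0, 2]].datasets)  # ['gsm8k', 'gsm8k']  (list indexing -> new Batch)
print(batch[1].queries)        # [[3, 4]]            (int indexing wraps fields in lists)
print(batch[0:2].indices)      # [0, 1]              (slice indexing)
```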