Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Updating get_statistics from tulu 1 for general use #196

Merged
merged 2 commits into from
Jul 22, 2024
Merged
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
91 changes: 67 additions & 24 deletions open_instruct/get_statistics.py
Original file line number Diff line number Diff line change
Expand Up @@ -19,25 +19,39 @@
import pandas as pd
import tqdm
from datasets import load_dataset
from huggingface_hub import repo_exists
from transformers import AutoTokenizer


def get_statistics_for_messages_data(data_path):
# load dataset
dataset = load_dataset("json", data_files={"train": data_path})
def get_statistics_for_messages_data(
data_path,
dataset=None,
split="train",
messages_key="messages",
tokenizer="/net/nfs.cirrascale/allennlp/yizhongw/hf_llama2_models/7B/",
):
if dataset is None:
# load dataset
dataset = load_dataset("json", data_files={split: data_path})
# tokenize dataset
tokenizer = AutoTokenizer.from_pretrained(
"/net/nfs.cirrascale/allennlp/yizhongw/hf_llama_models/7B", use_fast=False
)
tokenizer = AutoTokenizer.from_pretrained(tokenizer, use_fast=False)
# get statistics
num_instances = len(dataset["train"])
num_of_turns = [len(instance["messages"]) for instance in dataset["train"]]
num_instances = len(dataset[split])

# remove any messages that have "role" == "system"
def remove_system_messages(example):
example[args.messages_key] = [message for message in example[args.messages_key] if message["role"] != "system"]
return example

dataset = dataset.map(remove_system_messages)

num_of_turns = [len(instance[args.messages_key]) for instance in dataset[split]]
user_prompt_lengths = []
assistant_response_lengths = []
instance_lengths = []
for instance in tqdm.tqdm(dataset["train"], desc="Processing instances"):
for instance in tqdm.tqdm(dataset[split], desc="Processing instances"):
instance_length = 0
for message in instance["messages"]:
for message in instance[args.messages_key]:
if message["role"] == "user":
user_prompt_lengths.append(
len(tokenizer(message["content"], truncation=False, add_special_tokens=False)["input_ids"])
Expand All @@ -51,7 +65,10 @@ def get_statistics_for_messages_data(data_path):
instance_lengths.append(instance_length)

top_100_longest_instances = np.argsort(instance_lengths)[-100:][::-1].tolist()
top_100_longest_instances = [dataset["train"][i]["id"] for i in top_100_longest_instances]
if "id" in dataset[split].features:
top_100_longest_instances = [dataset[split][i]["id"] for i in top_100_longest_instances]
else:
top_100_longest_instances = None

result = {
"num_instances": num_instances,
Expand Down Expand Up @@ -80,17 +97,24 @@ def get_statistics_for_messages_data(data_path):
return result


def get_statistics_for_prompt_completion_data(data_path):
# load dataset
dataset = load_dataset("json", data_files={"train": data_path})
prompts = [instance["prompt"] for instance in dataset["train"]]
completions = [instance["completion"] for instance in dataset["train"]]
def get_statistics_for_prompt_completion_data(
data_path,
dataset=None,
split="train",
response_key="completion",
tokenizer="/net/nfs.cirrascale/allennlp/yizhongw/hf_llama2_models/7B/",
):
if dataset is None:
# load dataset
dataset = load_dataset("json", data_files={split: data_path})
prompts = [instance["prompt"] for instance in dataset[split]]
completions = [instance[response_key] for instance in dataset[split]]
# tokenize dataset
tokenizer = AutoTokenizer.from_pretrained("/net/nfs.cirrascale/allennlp/yizhongw/hf_llama_models/7B")
tokenizer = AutoTokenizer.from_pretrained(tokenizer)
tokenized_prompts = tokenizer(prompts, truncation=False, add_special_tokens=False)
tokenized_completions = tokenizer(completions, truncation=False, add_special_tokens=False)
# get statistics
num_instances = len(dataset["train"])
num_instances = len(dataset[split])
prompt_lengths = [len(tokenized_prompts["input_ids"][i]) for i in range(num_instances)]
completion_lengths = [len(tokenized_completions["input_ids"][i]) for i in range(num_instances)]
prompt_completion_lengths = [prompt_lengths[i] + completion_lengths[i] for i in range(num_instances)]
Expand Down Expand Up @@ -123,14 +147,33 @@ def get_statistics_for_prompt_completion_data(data_path):
parser = argparse.ArgumentParser()
parser.add_argument("--data_path", type=str, required=True)
parser.add_argument("--save_path", type=str, help="Path to save the statistics.")
parser.add_argument("--split", type=str, default="train")
parser.add_argument("--response_key", type=str, default="completion")
parser.add_argument("--messages_key", type=str, default="messages")
parser.add_argument("--tokenizer", type=str, default="/net/nfs.cirrascale/allennlp/yizhongw/hf_llama2_models/7B/")
args = parser.parse_args()

# Decide whether --data_path is a local JSON/JSONL file or a Hugging Face
# dataset id, then grab one sample record so we can sniff the schema below.
# NOTE: use an extension check, not a substring check ("json" in path),
# so dataset ids that merely contain "json" are not mistaken for files.
if args.data_path.endswith((".json", ".jsonl")):
    # Local file: peek at the first record only; the statistics functions
    # load the full file themselves when `dataset` is None.
    with open(args.data_path, "r") as f:
        sample = json.loads(f.readline())
    dataset = None
elif repo_exists(args.data_path, repo_type="dataset"):
    # Hub dataset: load once here and pass it through so the statistics
    # functions do not download it a second time.
    dataset = load_dataset(args.data_path)
    sample = dataset[args.split][0]
else:
    raise ValueError("Invalid data path - the data path should be either a dataset id or a path to a json file.")

# Dispatch on the sample's schema: chat-style messages data vs.
# prompt/completion pairs. Messages are checked first since --messages_key
# is configurable while "prompt" is a fixed key.
if args.messages_key in sample:
    statistics = get_statistics_for_messages_data(
        args.data_path, dataset=dataset, split=args.split, messages_key=args.messages_key, tokenizer=args.tokenizer
    )
elif "prompt" in sample:
    statistics = get_statistics_for_prompt_completion_data(
        args.data_path, dataset=dataset, split=args.split, response_key=args.response_key, tokenizer=args.tokenizer
    )
else:
    raise ValueError("Invalid data format - the data should be either prompt completion data or messages data.")

Expand Down
Loading