Skip to content

Commit

Permalink
Merge pull request #94 from oindrillac/messages
Browse files Browse the repository at this point in the history
Converts the dataset to the 'messages' format required for training
  • Loading branch information
russellb authored Jul 11, 2024
2 parents 7ef628f + b0fcf32 commit 7bf1563
Showing 1 changed file with 60 additions and 15 deletions.
75 changes: 60 additions & 15 deletions src/instructlab/sdg/generate_data.py
Original file line number Diff line number Diff line change
Expand Up @@ -78,27 +78,69 @@ def _get_response(logger, synth_example):
return parts[1].strip() if len(parts) == 2 else parts[0].strip()


def _gen_train_data(logger, output_datasets, output_file_train):
def _convert_to_messages(sample):
"""
Convert a sample dictionary to contain 'messages' and 'metadata' columns required for training.
"""
# Create user query message
user_query = sample["inputs"]
# TODO: in the future we can remove the combinecolumnsblock and combine them here for simplicity
# if "context" in sample:
# user_query = f"{sample['context']}\n\n{sample['inputs']}"

sample["messages"] = [
{"content": user_query, "role": "user"},
{"content": sample["targets"], "role": "assistant"},
]
metadata = {
key: value
for key, value in sample.items()
if key not in ["messages", "inputs", "targets"]
}
sample["metadata"] = json.dumps(metadata)

# keeping required keys for messages training format
sample = {"messages": sample["messages"], "metadata": sample["metadata"]}

return sample


def _gen_train_data(
    logger, machine_instruction_data, output_file_train, output_file_messages
):
    """
    Write synthesized examples out in two training formats.

    For every example in every dataset of `machine_instruction_data`, emit:
      - a legacy system/user/assistant record to `output_file_train`
      - a messages-format record (via `_convert_to_messages`) to
        `output_file_messages`

    Both output files are JSON Lines (one JSON object per line).
    """
    train_data = []
    messages_data = []

    for output_dataset in machine_instruction_data:
        for synth_example in output_dataset:
            logger.debug(synth_example)
            user = _get_question(logger, synth_example)
            # Append any context after the question so the model sees both.
            if len(synth_example.get("context", "")) > 0:
                user += "\n" + synth_example["context"]
            assistant = _unescape(_get_response(logger, synth_example))
            train_entry = {
                "system": _SYS_PROMPT,
                "user": _unescape(user),
                "assistant": assistant,
            }
            train_data.append(train_entry)
            sample = {
                "inputs": _unescape(user),
                "targets": assistant,
                "system": _SYS_PROMPT,
            }
            messages_data.append(_convert_to_messages(sample))

    _write_jsonl(output_file_train, train_data)
    _write_jsonl(output_file_messages, messages_data)


def _write_jsonl(output_file, entries):
    """Write `entries` to `output_file` as JSON Lines (one object per line)."""
    # ensure_ascii=False keeps non-ASCII text verbatim instead of \u escapes.
    with open(output_file, "w", encoding="utf-8") as outfile:
        for entry in entries:
            json.dump(entry, outfile, ensure_ascii=False)
            outfile.write("\n")


def _gen_test_data(
leaf_nodes,
Expand Down Expand Up @@ -219,16 +261,17 @@ def generate_data(

name = Path(model_name).stem # Just in case it is a file path
date_suffix = datetime.now().replace(microsecond=0).isoformat().replace(":", "_")
output_file_generated = f"generated_{name}_{date_suffix}.json"
output_file_messages = f"messages_{name}_{date_suffix}.json"
output_file_test = f"test_{name}_{date_suffix}.jsonl"
# train data in messages format that will be mixed and split up into train test eventually
output_file_train = f"train_{name}_{date_suffix}.jsonl"

_gen_test_data(
leaf_nodes,
os.path.join(output_dir, output_file_test),
)

logger.debug(f"Generating to: {os.path.join(output_dir, output_file_generated)}")
logger.debug(f"Generating to: {os.path.join(output_dir, output_file_test)}")

orig_cert = (tls_client_cert, tls_client_key, tls_client_passwd)
cert = tuple(item for item in orig_cert if item)
Expand Down Expand Up @@ -289,16 +332,18 @@ def generate_data(
if generated_data is None:
generated_data = []

_gen_train_data(logger, generated_data, os.path.join(output_dir, output_file_train))
_gen_train_data(
logger,
generated_data,
os.path.join(output_dir, output_file_train),
os.path.join(output_dir, output_file_messages),
)

# TODO
# This is for backwards compatibility. The file existing previously, so we'll keep it for now.
# I believe the github bot assumes it is present for presenting generated data to a taxonomy
# reviewer or contributor. Otherwise, I don't see a consumer of it in this repo or the
# `ilab` CLI.
_gen_train_data(
logger, generated_data, os.path.join(output_dir, output_file_generated)
)

generate_duration = time.time() - generate_start
logger.info(f"Generation took {generate_duration:.2f}s")

0 comments on commit 7bf1563

Please sign in to comment.