Commit 8955a22

chore: perform formatting
TianyiQ committed Dec 21, 2024
1 parent b3d4a09 commit 8955a22
Showing 21 changed files with 678 additions and 347 deletions.
2 changes: 2 additions & 0 deletions __init__.py
@@ -1,10 +1,12 @@
import os, sys

sys.path = [os.path.dirname(os.path.abspath(__file__))] + sys.path

if not eval(os.environ.get("LOUD_BACKEND", "0")):
    os.environ["WANDB_DISABLED"] = "true"

import logging

logging.basicConfig(level=logging.ERROR)

from benchmark.framework import JudgeBase, ExamineeBase
14 changes: 8 additions & 6 deletions build_dataset.py
@@ -29,7 +29,8 @@ def build_gutenberg():
     dir = f"{root}/dataset/raw_downloads/Gutenberg/"
     gtb_gd.get_data_gutenberg(dir)
     gtb_gm.gather_meta(
-        os.path.join(dir, "data/raw"), f"{root}/dataset/raw_downloads/Gutenberg_records.txt"
+        os.path.join(dir, "data/raw"),
+        f"{root}/dataset/raw_downloads/Gutenberg_records.txt",
     )
     print("======= FINISHED BUILDING GUTENBERG DATASET =======\n\n\n")

@@ -107,9 +108,7 @@ def build_pile_of_law():
     )

     # Make llm-cleansed version the official version ("dataset_text_sequence"), and move the other two versions into dataset/raw_downloads
-    path = (
-        f"{root}/dataset/raw_downloads/dataset_text_sequence_versions/{timestamp}/"
-    )
+    path = f"{root}/dataset/raw_downloads/dataset_text_sequence_versions/{timestamp}/"
     os.makedirs(path)

     print(f"Moving pre-cleansing version to backup folder...")
@@ -154,7 +153,9 @@ def build_pile_of_law():
     sub_datasets = [
         f
         for f in os.listdir(f"{root}/dataset/dataset_text_sequence/")
-        if os.path.isdir(os.path.join(f"{root}/dataset/dataset_text_sequence/", f))
+        if os.path.isdir(
+            os.path.join(f"{root}/dataset/dataset_text_sequence/", f)
+        )
     ]
     for sub in sub_datasets:
         # Remove if size < 10MB AND century number < 13
@@ -169,7 +170,8 @@ def build_pile_of_law():
             os.system(f"mv ./dataset/dataset_text_sequence/{sub} {path}")

     hislm.run_training(
-        f"{root}/dataset/dataset_text_sequence/", f"{root}/dataset/dataset_model_sequence/"
+        f"{root}/dataset/dataset_text_sequence/",
+        f"{root}/dataset/dataset_model_sequence/",
     )
     print("Finished model training. Exiting.")

26 changes: 15 additions & 11 deletions examples/abstractions/finetuning_datamanip.py
@@ -13,6 +13,7 @@
     is_instruct_finetuned=True,
 )

+
 def continue_pretrain():
     # ============== Continue pretraining from Gemma 2B ==============
     global gemma2b_c4
@@ -22,6 +23,7 @@ def continue_pretrain():
     )
     print(gemma2b_c4.is_instruct_finetuned) # False

+
 def supervised_finetune():
     # ============== Then do SFT using alpaca data ==============
     global gemma2b_c4_alpaca
@@ -34,7 +36,7 @@ def supervised_finetune():
     )
     print(gemma2b_c4_alpaca.is_instruct_finetuned) # True
     gemma2b_c4_alpaca.save_permanent() # saved to output/saved/saved_model/gemma-2b_c4_alpaca
-
+
     # ============== Or maybe, we should censor curse words before SFT ==============
     def remove_curse_words(sample_dict: dict) -> dict:
         filter = lambda s: (
@@ -55,7 +57,7 @@ def remove_curse_words(sample_dict: dict) -> dict:
     )
     gemma2b_c4_alpaca_G.save_permanent() # saved to output/saved/saved_model/gemma-2b_c4_alpaca_G
     alpaca_data_G.save_permanent_and_register() # saved to output/saved/saved_model/alpaca_gpt4_en_G.json & added to llama-factory dataset registry
-
+
     # ============== What about using our own data (scattered across multiple files in multiple directories) for finetuning? ==============
     histext_collection = DataFileCollection( # build a collection holding json files of year 1826 to 2018
         collection_name="histext_1826_to_2018_collection",
@@ -93,6 +95,7 @@ def remove_nonstr_data(sample_dict: dict) -> dict:
         result_model_name="gemma-2b_histext",
     )

+
 def direct_preference_optimization():
     # ============== Then do DPO using ORCA data ==============
     global gemma2b_c4_alpaca_orca
@@ -105,6 +108,7 @@ def direct_preference_optimization():
     )
     gemma2b_c4_alpaca_orca.save_permanent() # saved to output/saved/saved_model/gemma-2b_c4_alpaca_orca

+
 def dialogue_manipulation():
     # ============== Generating a dialogue, using a model to play the role of both user and assistant ==============
     global llama8b_instruct
@@ -115,33 +119,33 @@ def dialogue_manipulation():
"input": "Is Eiffel Tower in Paris?",
"history": [
["What is the capital of France?", "Paris."],
]
],
}
]
],
)

def converse():
nonlocal dialogue_data

dialogue_data = llama8b_instruct.inference(
dialogue_data, "dialogue_data2", backend="sglang"
)
dialogue_data = dialogue_data.switch_role_to_user()

dialogue_data = llama8b_instruct.inference(
dialogue_data, "dialogue_data3", backend="sglang"
)
dialogue_data = dialogue_data.switch_role_to_assistant()

for i in range(5):
converse()

print(list(dialogue_data.all_passages()))
print(list(dialogue_data.to_openai_format()))


if __name__ == "__main__":
continue_pretrain()
supervised_finetune()
direct_preference_optimization()
dialogue_manipulation()
dialogue_manipulation()
15 changes: 9 additions & 6 deletions examples/abstractions/inference_evaluation.py
@@ -30,9 +30,9 @@ def dataset_inference_example(histllama: Model):

 def logprob_example(histllama: Model):
     custom_data = Data(
-        "custom_data", 
+        "custom_data",
         data_type="sft",
-        data_content = [
+        data_content=[
             {
                 "input": "What is the capital of France?",
                 "predict": ["Paris", "Washington D.C.", "London", "Berlin"],
@@ -42,9 +42,12 @@ def logprob_example(histllama: Model):
     custom_data.set_key_fields(query_field_name="input")

     logprob_output: Data = histllama.inference(
-        custom_data, "8B-C021-infer-custom-deepspeed", backend="sglang", purpose="logprobs"
+        custom_data,
+        "8B-C021-infer-custom-deepspeed",
+        backend="sglang",
+        purpose="logprobs",
     )
-    print(list(logprob_output.all_passages())) 
+    print(list(logprob_output.all_passages()))
     # [{'predict': ['Paris', 'Washington D.C.', 'London', 'Berlin'], 'input': 'What is the capital of France?', 'logprob': [-9.92294692993164, -17.21290510520339, -11.677074432373047, -12.903636932373047]}]


@@ -58,6 +61,6 @@ def logprob_example(histllama: Model):
     # model_path_or_repoid="mistralai/Mixtral-8x7B-Instruct-v0.1",
     # template_type="mistral",
     # )

     dataset_inference_example(histllama)
-    logprob_example(histllama)
\ No newline at end of file
+    logprob_example(histllama)
5 changes: 4 additions & 1 deletion libs/llama_factory/src/llmtuner/data/template.py
@@ -713,7 +713,10 @@ def get_template_and_fix_tokenizer(
         ]
     ),
     format_system=StringFormatter(
-        slots=[{"bos_token"}, "<|start_header_id|>system<|end_header_id|>\n\n{{content}}<|eot_id|>"]
+        slots=[
+            {"bos_token"},
+            "<|start_header_id|>system<|end_header_id|>\n\n{{content}}<|eot_id|>",
+        ]
     ),
     format_observation=StringFormatter(
         slots=[