From 3397a4a41f68862e041b8284d306c5fba7ac4fb0 Mon Sep 17 00:00:00 2001 From: Matheus Pereira Date: Tue, 23 Jul 2024 14:10:38 -0700 Subject: [PATCH 01/31] initial chat datamodule --- mttl/cli/jsonl_to_hf_chat_dataset.py | 44 +++++++++++++++++++++ mttl/datamodule/base.py | 9 +++++ mttl/datamodule/chat_data_module.py | 14 +++++++ mttl/evaluators/evaluators.py | 2 +- projects/modular_llm/src/transfer_matrix.py | 2 +- 5 files changed, 69 insertions(+), 2 deletions(-) create mode 100644 mttl/cli/jsonl_to_hf_chat_dataset.py create mode 100644 mttl/datamodule/chat_data_module.py diff --git a/mttl/cli/jsonl_to_hf_chat_dataset.py b/mttl/cli/jsonl_to_hf_chat_dataset.py new file mode 100644 index 000000000..c3ef5279e --- /dev/null +++ b/mttl/cli/jsonl_to_hf_chat_dataset.py @@ -0,0 +1,44 @@ +import os +import json + +import click +from datasets import load_dataset +from transformers import AutoModelForCausalLM, AutoTokenizer + + +@click.command() +@click.option('--input_jsonl', help='Path to the input jsonl file') +@click.option('--model', help='Model name or path to the model') +@click.option( + '--output_dataset', + help='Path to the output hf dataset. Same as input file but withouth the extension if not provided.', + default=None, +) +def main(input_jsonl, model, output_dataset): + if output_dataset is None: + output_dataset, _ = os.path.splitext(input_jsonl) + dataset = load_dataset("json", data_files=input_jsonl) + + tokenizer = AutoTokenizer.from_pretrained(model) + model = AutoModelForCausalLM.from_pretrained(model) + + dataset = dataset.map( + lambda x: { + "formatted_chat": tokenizer.apply_chat_template( + json.loads(x["messages"]), + tokenize=False, + add_generation_prompt=False, + ), + "task_name": json.loads( + x["metadata"] if x["metadata"] else "{}" + ).get("task_name", "unknown"), + }, + num_proc=os.environ.get("MTTL_NUM_PROC_DATASETS", 16), + ) + + dataset.save_to_disk(output_dataset) + + +if __name__ == "__main__": + main() + diff --git a/mttl/datamodule/base.py b/mttl/datamodule/base.py index 8cbb8ee3d..0e66a0fcd 100644 --- a/mttl/datamodule/base.py +++ b/mttl/datamodule/base.py @@ -652,6 +652,10 @@ def get_datamodule(args, for_generation=False, dataset_override=None): WinograndeDataConfig, WinograndeMultiChoiceDataModule, ) + from mttl.datamodule.chat_data_module import ( + ChatDataConfig, + ChatDataModule, + ) # refactor all the common arguments below into a dict common kwargs dataset = args.dataset if not dataset_override else dataset_override @@ -721,6 +725,11 @@ def get_datamodule(args, for_generation=False, dataset_override=None): assert not for_generation config = dataset_to_klass_map[dataset][0] dm = dataset_to_klass_map[dataset][1](config) + elif "chat" in dataset: + config = ChatDataConfig( + **common_kwargs, + ) + dm = ChatDataModule(config, for_generation=for_generation) elif "flan" in dataset: config = FlanConfig( **common_kwargs, diff --git a/mttl/datamodule/chat_data_module.py b/mttl/datamodule/chat_data_module.py new file mode 100644 index 000000000..428129f91 --- /dev/null +++ b/mttl/datamodule/chat_data_module.py @@ -0,0 +1,14 @@ +from mttl.datamodule.base import DatasetConfig, DefaultDataModule +from mttl.models.library.expert_library import DatasetLibrary + + +class ChatDataConfig(DatasetConfig): + pass + + +class ChatDataModule(DefaultDataModule): + + def setup_dataset(self): + dataset = DatasetLibrary.pull_dataset(self.config.dataset, split="train") + # TODO: continue implementation + diff --git a/mttl/evaluators/evaluators.py b/mttl/evaluators/evaluators.py 
index 9b609f234..8861f25e2 100644 --- a/mttl/evaluators/evaluators.py +++ b/mttl/evaluators/evaluators.py @@ -4,7 +4,6 @@ import torch -from mttl.callbacks import TestLossEvaluator from mttl.datamodule.base import DefaultDataModule, get_datamodule from mttl.evaluators import MMLUEvaluator, RougeEvaluator from mttl.models.expert_config import ExpertConfig @@ -18,6 +17,7 @@ def prepare_evaluator( subsample=-1, for_generation=None, ): + from mttl.callbacks import TestLossEvaluator if args.eval_metric == "loss": EVAL_CLASS = TestLossEvaluator for_generation = for_generation if for_generation is not None else False diff --git a/projects/modular_llm/src/transfer_matrix.py b/projects/modular_llm/src/transfer_matrix.py index 170dcfc5e..5576262a2 100644 --- a/projects/modular_llm/src/transfer_matrix.py +++ b/projects/modular_llm/src/transfer_matrix.py @@ -18,7 +18,7 @@ from mttl.models.library.expert_library import ExpertLibrary from mttl.utils import remote_login from mttl.vllm_engines.engines import free_memory -from projects.modular_llm.src.utils.evaluators import Evaluator, prepare_evaluator +from mttl.evaluators.evaluators import Evaluator, prepare_evaluator DEBUG = False if "AMLT_OUTPUT_DIR" in os.environ: From 8a1fea8708d7eefc2524db910cd563838b570b5d Mon Sep 17 00:00:00 2001 From: Matheus Pereira Date: Thu, 25 Jul 2024 07:25:09 -0700 Subject: [PATCH 02/31] Split chat into individual chat turns --- mttl/cli/jsonl_to_hf_chat_dataset.py | 70 ++++++++++++++++++++-------- mttl/datamodule/chat_data_module.py | 21 ++++++++- 2 files changed, 69 insertions(+), 22 deletions(-) diff --git a/mttl/cli/jsonl_to_hf_chat_dataset.py b/mttl/cli/jsonl_to_hf_chat_dataset.py index c3ef5279e..88036c1ed 100644 --- a/mttl/cli/jsonl_to_hf_chat_dataset.py +++ b/mttl/cli/jsonl_to_hf_chat_dataset.py @@ -1,39 +1,70 @@ -import os import json +import os import click from datasets import load_dataset -from transformers import AutoModelForCausalLM, AutoTokenizer +from transformers import AutoTokenizer @click.command() -@click.option('--input_jsonl', help='Path to the input jsonl file') -@click.option('--model', help='Model name or path to the model') +@click.option("--input_jsonl", help="Path to the input jsonl file") +@click.option("--model", help="Model name or path to the model") @click.option( - '--output_dataset', - help='Path to the output hf dataset. Same as input file but withouth the extension if not provided.', + "--output_dataset", + help="Path to the output hf dataset. Same as input file but with no extension if not provided.", default=None, ) def main(input_jsonl, model, output_dataset): if output_dataset is None: output_dataset, _ = os.path.splitext(input_jsonl) - dataset = load_dataset("json", data_files=input_jsonl) + + num_proc = os.environ.get("MTTL_NUM_PROC_DATASETS", 16) tokenizer = AutoTokenizer.from_pretrained(model) - model = AutoModelForCausalLM.from_pretrained(model) + + def apply_chat_template(example): + return tokenizer.apply_chat_template( + example, + tokenize=False, + add_generation_prompt=False, + ) + + def chat_progression(examples): + """Split chat into individual chat turns. 
For each turn, the source is + the chat up to that point and the target is the assistant's message.""" + sources = [] + targets = [] + task_names = [] + num_rounds = [] + for messages, metadata in zip(examples["messages"], examples["metadata"]): + messages = json.loads(messages) + task_name = json.loads(metadata or "{}").get("task_name", "unknown") + chat_progression = [] + rounds = 1 + for message in messages: + if message["role"] != "assistant": + chat_progression.append(message) + else: + sources.append(apply_chat_template(list(chat_progression))) + targets.append(apply_chat_template([dict(message)])) + task_names.append(task_name) + num_rounds.append(rounds) + chat_progression.append(message) + rounds += 1 + return { + "source": sources, + "target": targets, + "task_name": task_names, + "round": num_rounds, + } + + dataset = load_dataset("json", data_files=input_jsonl) dataset = dataset.map( - lambda x: { - "formatted_chat": tokenizer.apply_chat_template( - json.loads(x["messages"]), - tokenize=False, - add_generation_prompt=False, - ), - "task_name": json.loads( - x["metadata"] if x["metadata"] else "{}" - ).get("task_name", "unknown"), - }, - num_proc=os.environ.get("MTTL_NUM_PROC_DATASETS", 16), + chat_progression, + batched=True, # allows to return more examples than the input + remove_columns=dataset["train"].column_names, + num_proc=num_proc, ) dataset.save_to_disk(output_dataset) @@ -41,4 +72,3 @@ def main(input_jsonl, model, output_dataset): if __name__ == "__main__": main() - diff --git a/mttl/datamodule/chat_data_module.py b/mttl/datamodule/chat_data_module.py index 428129f91..5d4315027 100644 --- a/mttl/datamodule/chat_data_module.py +++ b/mttl/datamodule/chat_data_module.py @@ -3,12 +3,29 @@ class ChatDataConfig(DatasetConfig): - pass + chat_template: str = None # TODO: load and apply custom chat template + seed: str = 42 class ChatDataModule(DefaultDataModule): def setup_dataset(self): dataset = DatasetLibrary.pull_dataset(self.config.dataset, split="train") - # TODO: continue implementation + num_examples = len(dataset) + num_train = int(0.8 * num_examples) + num_dev = int(0.1 * num_examples) + + dataset = dataset.shuffle(seed=self.config.seed) + + # use maybe_filter_hf_dataset_by_task instead? 
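+        # Task bookkeeping is left empty for now; the commented-out lines below
+        # sketch how task names and ids could be derived from the dataset's task_name column.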
+ self._task_names = [] + self._task_to_id = {} + # self._task_names = sorted(list(set(dataset['task_name']))) + # self._task_to_id = { + # task_name: i for i, task_name in enumerate(self._task_names) + # } + + self.train_dataset = dataset.select(range(num_train)) + self.dev_dataset = dataset.select(range(num_train, num_train + num_dev)) + self.test_dataset = dataset.select(range(num_train + num_dev, num_examples)) From ea0ab8bcb36c7f282070d1b8a726fd4b726657de Mon Sep 17 00:00:00 2001 From: Matheus Pereira Date: Thu, 25 Jul 2024 11:34:33 -0700 Subject: [PATCH 03/31] black and isort --- mttl/datamodule/base.py | 5 +---- 1 file changed, 1 insertion(+), 4 deletions(-) diff --git a/mttl/datamodule/base.py b/mttl/datamodule/base.py index 0e66a0fcd..4481a751a 100644 --- a/mttl/datamodule/base.py +++ b/mttl/datamodule/base.py @@ -624,6 +624,7 @@ def collate_fn(self): def get_datamodule(args, for_generation=False, dataset_override=None): from mttl.datamodule.arc_data_module import ArcDataConfig, ArcMultiChoiceDataModule + from mttl.datamodule.chat_data_module import ChatDataConfig, ChatDataModule from mttl.datamodule.codex_data_module import CodexDataConfig, CodexDataModule from mttl.datamodule.hellaswag_data_module import ( HellaswagDataConfig, @@ -652,10 +653,6 @@ def get_datamodule(args, for_generation=False, dataset_override=None): WinograndeDataConfig, WinograndeMultiChoiceDataModule, ) - from mttl.datamodule.chat_data_module import ( - ChatDataConfig, - ChatDataModule, - ) # refactor all the common arguments below into a dict common kwargs dataset = args.dataset if not dataset_override else dataset_override From 5b2e3cdb8af1273313d1798e4dfde18f90cc056f Mon Sep 17 00:00:00 2001 From: Matheus Pereira Date: Wed, 31 Jul 2024 11:56:42 -0700 Subject: [PATCH 04/31] Added get_clusters from #70 --- mttl/models/ranker/classifier_ranker.py | 7 ++ projects/modular_llm/get_clusters.py | 141 ++++++++++++++++++++++++ 2 files changed, 148 insertions(+) create mode 100644 projects/modular_llm/get_clusters.py diff --git a/mttl/models/ranker/classifier_ranker.py b/mttl/models/ranker/classifier_ranker.py index d7b7312d6..f5907d10e 100644 --- a/mttl/models/ranker/classifier_ranker.py +++ b/mttl/models/ranker/classifier_ranker.py @@ -85,6 +85,13 @@ def __init__( self.out_projecter = nn.Linear(hidden_size, self.num_labels) self.save_hyperparameters(ignore=["text_encoder"]) + def get_text_encode(self, x): + # Encode the text input + text_output = self.text_encoder(x) + # conver the text output to hidden vector + text_output_projecter = self.text_projecter(text_output) + return text_output_projecter + def forward(self, x): # Encode the text input text_output = self.text_encoder(x) diff --git a/projects/modular_llm/get_clusters.py b/projects/modular_llm/get_clusters.py new file mode 100644 index 000000000..2835d2712 --- /dev/null +++ b/projects/modular_llm/get_clusters.py @@ -0,0 +1,141 @@ +import argparse +import os +from datetime import datetime + +import numpy as np +import torch +from pytorch_lightning import seed_everything +from sentence_transformers import SentenceTransformer +from sklearn.cluster import KMeans +from torch.utils.data import DataLoader, Subset +from tqdm import tqdm + +from mttl.logging import logger, setup_logging +from mttl.models.library.expert_library import DatasetLibrary +from mttl.models.ranker.adapter_ranker import AdapterRankerHelper +from mttl.utils import remote_login + + +def get_dataset(args): + dataset = DatasetLibrary.pull_dataset(args.dataset, split="train") + + # create the 
subsample of the dataset if cutoff is set. + if args.cutoff > 0: + dataset = dataset.shuffle(seed=args.seed) + dataset = dataset.select(range(args.cutoff)) + + dataset_size = len(dataset) + indices = list(range(dataset_size)) + np.random.shuffle(indices) + split = int(np.floor(args.subsample * dataset_size)) + subset_indices = indices[:split] + subset_dataset = Subset(dataset, subset_indices) + + train_dataloader = DataLoader( + subset_dataset, batch_size=args.batch_size, num_workers=args.num_workers + ) + all_dataloader = DataLoader( + dataset, batch_size=args.batch_size, num_workers=args.num_workers + ) + + return train_dataloader, all_dataloader, dataset + + +class EncoderModel: + + def __init__(self, encoding, model, device="cuda"): + if encoding == "classifier": + model = AdapterRankerHelper.get_ranker_instance( + ranker_model="classifier", + ranker_path=model, + device=device, + ) + model.encoder_fn = model.get_text_encode + postprocess = lambda x: x.cpu().detach().numpy() + elif encoding == "embedding": + model = SentenceTransformer(model, device=device) + model.encoder_fn = model.encode + postprocess = lambda x: x + else: + raise ValueError(f"Invalid encoding: {encoding}") + + self.model = model.to(device) + self.encoding = encoding + self.postprocess = postprocess + + def get_text_encode(self, text): + return self.postprocess(self.model.encoder_fn(text)) + + +def parse_args(): + parser = argparse.ArgumentParser() + parser.add_argument("--dataset", type=str, required=True) + parser.add_argument("--encoding", type=str, default="embedding") + parser.add_argument( + "--model", type=str, default="sentence-transformers/sentence-t5-xxl" + ) + parser.add_argument("--batch_size", type=int, default=128) + parser.add_argument( + "--cutoff", type=int, default=-1, help="Number of examples to use." 
+ ) + parser.add_argument("--subsample", type=float, default=0.2) + parser.add_argument("--num_clusters", type=int, default=8) + parser.add_argument( + "--output_dir", type=str, default=os.getenv("OUTPUT_DIR", "./output") + ) + parser.add_argument( + "--num_workers", type=int, default=1, help="Number of workers for dataloader" + ) + parser.add_argument("--seed", type=int, default=42) + args = parser.parse_args() + return args + + +def main(): + args = parse_args() + + setup_logging(args.output_dir) + logger.info("Args: {}".format(vars(args))) + remote_login() + seed_everything(args.seed) + device = torch.device("cuda" if torch.cuda.is_available() else "cpu") + + model = EncoderModel(args.encoding, args.model, device=device) + train_dataloader, all_dataloader, all_dataset = get_dataset(args) + + with torch.no_grad(): + embedding_list = [ + model.get_text_encode(batch["source"]) + for batch in tqdm( + train_dataloader, total=len(train_dataloader), desc="dataset" + ) + ] + + all_embedding = np.concatenate(embedding_list, axis=0) + logger.info(f"all_embedding shape: {all_embedding.shape}") + + kmeans = KMeans( + n_clusters=args.num_clusters, + init="k-means++", + n_init=10, + random_state=args.seed, + ).fit(all_embedding) + + def add_cluster_id(example): + embedding = model.get_text_encode(example["source"]) + example["cluster_id"] = [str(i) for i in kmeans.predict(embedding)] + return example + + dataset = all_dataset.map(add_cluster_id, batched=True, batch_size=args.batch_size) + + dataset_name = ( + f"local://{args.output_dir}/clusters-" + f"{args.num_clusters}-{datetime.now().isoformat()}" + ) + + logger.info(f"Pushing dataset to {dataset_name}") + DatasetLibrary.push_dataset(dataset, dataset_name) + + +if __name__ == "__main__": + main() From df2e2aade3d37fd11fa8c018b223bce1daba9d78 Mon Sep 17 00:00:00 2001 From: Matheus Pereira Date: Fri, 2 Aug 2024 05:40:47 -0700 Subject: [PATCH 05/31] lazy (literaly) setup_dataset and get_datamodule --- mttl/datamodule/base.py | 2 +- mttl/datamodule/chat_data_module.py | 13 ++++++------- 2 files changed, 7 insertions(+), 8 deletions(-) diff --git a/mttl/datamodule/base.py b/mttl/datamodule/base.py index 4481a751a..4319601c0 100644 --- a/mttl/datamodule/base.py +++ b/mttl/datamodule/base.py @@ -722,7 +722,7 @@ def get_datamodule(args, for_generation=False, dataset_override=None): assert not for_generation config = dataset_to_klass_map[dataset][0] dm = dataset_to_klass_map[dataset][1](config) - elif "chat" in dataset: + elif "chat" in dataset or "clusters" in dataset: config = ChatDataConfig( **common_kwargs, ) diff --git a/mttl/datamodule/chat_data_module.py b/mttl/datamodule/chat_data_module.py index 5d4315027..69d13fb26 100644 --- a/mttl/datamodule/chat_data_module.py +++ b/mttl/datamodule/chat_data_module.py @@ -18,13 +18,12 @@ def setup_dataset(self): dataset = dataset.shuffle(seed=self.config.seed) - # use maybe_filter_hf_dataset_by_task instead? 
- self._task_names = [] - self._task_to_id = {} - # self._task_names = sorted(list(set(dataset['task_name']))) - # self._task_to_id = { - # task_name: i for i, task_name in enumerate(self._task_names) - # } + dataset = dataset.rename_column("task_name", "origin_task_name") + dataset = dataset.rename_column("cluster_id", "task_name") + self._task_names = sorted(list(set(dataset["task_name"]))) + self._task_to_id = { + task_name: i for i, task_name in enumerate(self._task_names) + } self.train_dataset = dataset.select(range(num_train)) self.dev_dataset = dataset.select(range(num_train, num_train + num_dev)) From f1e6ac896c9c0ad4941b3a80c3ac53ab55d6096a Mon Sep 17 00:00:00 2001 From: Alessandro Sordoni Date: Fri, 2 Aug 2024 11:39:19 -0700 Subject: [PATCH 06/31] add pack sequences --- mttl/datamodule/base.py | 2 ++ mttl/datamodule/chat_data_module.py | 1 - 2 files changed, 2 insertions(+), 1 deletion(-) diff --git a/mttl/datamodule/base.py b/mttl/datamodule/base.py index c41940ae1..edd6ef1e3 100644 --- a/mttl/datamodule/base.py +++ b/mttl/datamodule/base.py @@ -777,6 +777,7 @@ def post_setup_dataset(self): if self.config.pack_sequences and split == "train": dataset = getattr(self, f"{split}_dataset") logger.info(f"Packing sequences for {split} dataset") + dataset = self.tokenize_dataset(dataset) dataset = self.pack_sequences( dataset, max_sequences=self.config.max_seq_per_pack @@ -883,6 +884,7 @@ def get_datamodule(args, for_generation=False, dataset_override=None): "pad_to_multiple_of": args.pad_to_multiple_of, "padding_side": args.padding_side, "max_seq_per_pack": args.max_seq_per_pack, + "pack_sequences": args.pack_sequences, } if dataset in [ diff --git a/mttl/datamodule/chat_data_module.py b/mttl/datamodule/chat_data_module.py index 69d13fb26..fb91e3fb7 100644 --- a/mttl/datamodule/chat_data_module.py +++ b/mttl/datamodule/chat_data_module.py @@ -8,7 +8,6 @@ class ChatDataConfig(DatasetConfig): class ChatDataModule(DefaultDataModule): - def setup_dataset(self): dataset = DatasetLibrary.pull_dataset(self.config.dataset, split="train") From a944701da27e29386ef6da203004946fbc88e3f9 Mon Sep 17 00:00:00 2001 From: Alessandro Sordoni Date: Fri, 2 Aug 2024 15:56:16 -0700 Subject: [PATCH 07/31] update config file --- projects/modular_llm/configs/models/mistral_7b.json | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/projects/modular_llm/configs/models/mistral_7b.json b/projects/modular_llm/configs/models/mistral_7b.json index 96b75e0dc..c560a41c8 100644 --- a/projects/modular_llm/configs/models/mistral_7b.json +++ b/projects/modular_llm/configs/models/mistral_7b.json @@ -18,8 +18,8 @@ "model_family": "gpt", "optimizer": "adamw", "warmup_proportion": 0.06, - "max_input_length": 2048, - "max_output_length": 64, + "max_input_length": 4096, + "max_output_length": 1024, "truncation_side": "left", "pipeline_eval_tasks": "all" } \ No newline at end of file From 01a2113e71f17e9f544f035f03e19671aaca03cb Mon Sep 17 00:00:00 2001 From: Alessandro Sordoni Date: Sat, 3 Aug 2024 15:22:35 -0700 Subject: [PATCH 08/31] quant config --- mttl/models/utils.py | 21 ++++++++++++++++----- 1 file changed, 16 insertions(+), 5 deletions(-) diff --git a/mttl/models/utils.py b/mttl/models/utils.py index 09225c431..484f0a2ad 100644 --- a/mttl/models/utils.py +++ b/mttl/models/utils.py @@ -10,6 +10,7 @@ import pytorch_lightning as pl import torch from pytorch_lightning import LightningModule +from transformers import BitsAndBytesConfig from transformers.file_utils import PushToHubMixin from 
transformers.utils import cached_file @@ -446,6 +447,18 @@ def model_loader_helper( from transformers import AutoModelForCausalLM, LlamaForCausalLM, PreTrainedModel + if load_in_8bit: + bnb_config = BitsAndBytesConfig(load_in_8bit=True) + elif load_in_4bit: + bnb_config = BitsAndBytesConfig( + load_in_4bit=True, + bnb_4bit_quant_type="nf4", + bnb_4bit_compute_dtype=torch.bfloat16, + bnb_4bit_use_double_quant=True, + ) + else: + bnb_config = None + logger.info(f"Attention Implementation: {attn_implementation}") if isinstance(model_name, PreTrainedModel): @@ -454,8 +467,7 @@ def model_loader_helper( if "llama" in model_name: model_object = LlamaForCausalLM.from_pretrained( model_name, - load_in_4bit=load_in_4bit, - load_in_8bit=load_in_8bit, + quantization_config=bnb_config, torch_dtype=torch.bfloat16, device_map=device_map, attn_implementation=attn_implementation, @@ -468,7 +480,7 @@ def model_loader_helper( logger.info(f"Loading phi-2 model from {os.environ['PHI_PATH']}") model_object = AutoModelForCausalLM.from_pretrained( os.environ["PHI_PATH"], - load_in_8bit=load_in_8bit, + quantization_config=bnb_config, torch_dtype=torch.bfloat16, device_map=device_map, trust_remote_code=True, @@ -482,8 +494,7 @@ def model_loader_helper( model_object = AutoModelForCausalLM.from_pretrained( model_name, device_map=device_map, - load_in_4bit=load_in_4bit, - load_in_8bit=load_in_8bit, + quantization_config=bnb_config, trust_remote_code=True, attn_implementation=attn_implementation, torch_dtype=torch.bfloat16, From d83d8a855364ec793a9233380d81e3260b54d29b Mon Sep 17 00:00:00 2001 From: Alessandro Sordoni Date: Sat, 3 Aug 2024 15:22:54 -0700 Subject: [PATCH 09/31] task name can be int --- mttl/datamodule/utils.py | 8 ++------ 1 file changed, 2 insertions(+), 6 deletions(-) diff --git a/mttl/datamodule/utils.py b/mttl/datamodule/utils.py index 55207087a..4e5b57357 100644 --- a/mttl/datamodule/utils.py +++ b/mttl/datamodule/utils.py @@ -19,12 +19,8 @@ def maybe_filter_hf_dataset_by_task( if "test" in dataset: all_tasks = all_tasks.union(set(dataset["test"][task_field])) - if task_names: - task_names = ( - sorted(task_names.split(",")) - if isinstance(task_names, str) - else sorted(task_names) - ) + if task_names is not None: + task_names = sorted(str(task_names).split(",")) if not set(task_names).issubset(all_tasks): raise ValueError( "task_names must be a subset of the available tasks. 
Got {} and {}".format( From 778ca090c2123b3f0dcbe413626572efb5e17e27 Mon Sep 17 00:00:00 2001 From: Alessandro Sordoni Date: Sat, 3 Aug 2024 15:24:04 -0700 Subject: [PATCH 10/31] chat datamodule filter task name --- mttl/datamodule/chat_data_module.py | 53 ++++++++++++++++++++--------- 1 file changed, 36 insertions(+), 17 deletions(-) diff --git a/mttl/datamodule/chat_data_module.py b/mttl/datamodule/chat_data_module.py index fb91e3fb7..bc1e8128c 100644 --- a/mttl/datamodule/chat_data_module.py +++ b/mttl/datamodule/chat_data_module.py @@ -1,29 +1,48 @@ +import os +from functools import partial + +import numpy + from mttl.datamodule.base import DatasetConfig, DefaultDataModule +from mttl.datamodule.mt_seq_to_seq_module import ( + FlatMultiTaskConfig, + apply_source_template, +) +from mttl.datamodule.utils import maybe_filter_hf_dataset_by_task +from mttl.logging import logger from mttl.models.library.expert_library import DatasetLibrary -class ChatDataConfig(DatasetConfig): - chat_template: str = None # TODO: load and apply custom chat template +class ChatDataConfig(FlatMultiTaskConfig): seed: str = 42 class ChatDataModule(DefaultDataModule): def setup_dataset(self): - dataset = DatasetLibrary.pull_dataset(self.config.dataset, split="train") - - num_examples = len(dataset) - num_train = int(0.8 * num_examples) + n_proc = int(os.environ.get("MTTL_NUM_PROC_DATASETS", 16)) + dataset = DatasetLibrary.pull_dataset(self.config.dataset) + + ( + self._task_names, + self._task_to_id, + train_dataset, + _, + _, + ) = maybe_filter_hf_dataset_by_task( + dataset, + "cluster_id", + self.config.finetune_task_name, + n_proc=n_proc, + ) + + num_examples = len(train_dataset) + num_train = int(0.9 * num_examples) num_dev = int(0.1 * num_examples) - dataset = dataset.shuffle(seed=self.config.seed) - - dataset = dataset.rename_column("task_name", "origin_task_name") - dataset = dataset.rename_column("cluster_id", "task_name") - self._task_names = sorted(list(set(dataset["task_name"]))) - self._task_to_id = { - task_name: i for i, task_name in enumerate(self._task_names) - } + train_dataset = train_dataset.shuffle(seed=self.config.seed) + train_dataset = train_dataset.rename_column("task_name", "original_task_name") + train_dataset = train_dataset.rename_column("cluster_id", "task_name") - self.train_dataset = dataset.select(range(num_train)) - self.dev_dataset = dataset.select(range(num_train, num_train + num_dev)) - self.test_dataset = dataset.select(range(num_train + num_dev, num_examples)) + self.train_dataset = train_dataset.select(range(num_train)) + self.dev_dataset = train_dataset.select(range(num_train, num_train + num_dev)) + self.test_dataset = self.dev_dataset From 40d974b48d58a4cf8eea0feb25b0f7c31edb25f6 Mon Sep 17 00:00:00 2001 From: Alessandro Sordoni Date: Sat, 3 Aug 2024 15:24:24 -0700 Subject: [PATCH 11/31] prepare kbit on 4bit --- mttl/models/expert_model.py | 7 ++++--- 1 file changed, 4 insertions(+), 3 deletions(-) diff --git a/mttl/models/expert_model.py b/mttl/models/expert_model.py index edcf7a7ef..9f48409ed 100644 --- a/mttl/models/expert_model.py +++ b/mttl/models/expert_model.py @@ -59,15 +59,15 @@ def __init__(self, **kwargs): attn_implementation=getattr(self.hparams, "attn_implementation", None), ) - if self.load_in_8bit: - model_object = prepare_model_for_kbit_training(model_object) - # rebuild the training config, a bit cumbersome, but that's life self.training_config = ExpertConfig.fromdict(kwargs) self.training_config.vocab_size = ( 
model_object.get_input_embeddings().num_embeddings ) + if self.load_in_8bit or self.load_in_4bit: + model_object = prepare_model_for_kbit_training(model_object) + # init the transformer just with the modifier config, this avoids # passing the whole training config to the modify_transformer func self.modifier_config = ModifierConfig.from_training_config(self.training_config) @@ -83,6 +83,7 @@ def forward(self, batch, reduction="mean"): input_ids = batch["input_ids"] labels = batch["labels"] + print(input_ids.shape[-1]) outputs = self.model.forward(input_ids, attention_mask=batch["attention_mask"]) # calculate loss, could also be done inside of the model From 1b60a8acfc0f8197af536e3293c900a656e1848e Mon Sep 17 00:00:00 2001 From: Alessandro Sordoni Date: Sat, 3 Aug 2024 15:24:46 -0700 Subject: [PATCH 12/31] paged optim on bnb --- mttl/models/get_optimizer.py | 12 +++++++++++- 1 file changed, 11 insertions(+), 1 deletion(-) diff --git a/mttl/models/get_optimizer.py b/mttl/models/get_optimizer.py index e85755147..f77a48164 100644 --- a/mttl/models/get_optimizer.py +++ b/mttl/models/get_optimizer.py @@ -7,6 +7,13 @@ from mttl.logging import logger +def instantiate_bnb_optimizer(model_parameters, **kwargs): + import bitsandbytes as bnb + + optimizer = bnb.optim.PagedAdamW(model_parameters, **kwargs) + return optimizer + + def get_optimizer(model, args, no_decay=None): """ Construct optimizer based on args @@ -92,7 +99,10 @@ def get_optimizer(model, args, no_decay=None): # from transformers import AdamW # tloen uses adamw_torch from torch.optim import AdamW - optimizer = AdamW(param_groups, eps=args.adam_epsilon) + if args.load_in_4bit or args.load_in_8bit: + optimizer = instantiate_bnb_optimizer(param_groups, eps=args.adam_epsilon) + else: + optimizer = AdamW(param_groups, eps=args.adam_epsilon) elif optim_name.lower() == "adafactor": optimizer = Adafactor( param_groups, From 8b06f7c44b2a8b4fc0e4192c56a9792187e43ee5 Mon Sep 17 00:00:00 2001 From: Alessandro Sordoni Date: Sat, 3 Aug 2024 15:25:06 -0700 Subject: [PATCH 13/31] trust --- projects/modular_llm/get_clusters.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/projects/modular_llm/get_clusters.py b/projects/modular_llm/get_clusters.py index 2835d2712..15a49a141 100644 --- a/projects/modular_llm/get_clusters.py +++ b/projects/modular_llm/get_clusters.py @@ -53,7 +53,7 @@ def __init__(self, encoding, model, device="cuda"): model.encoder_fn = model.get_text_encode postprocess = lambda x: x.cpu().detach().numpy() elif encoding == "embedding": - model = SentenceTransformer(model, device=device) + model = SentenceTransformer(model, device=device, trust_remote_code=True) model.encoder_fn = model.encode postprocess = lambda x: x else: From d724f20883e8732a8d8efee54e8fcd75e716b339 Mon Sep 17 00:00:00 2001 From: Alessandro Sordoni Date: Mon, 5 Aug 2024 12:09:17 -0700 Subject: [PATCH 14/31] remove unused config options --- mttl/config.py | 4 ---- 1 file changed, 4 deletions(-) diff --git a/mttl/config.py b/mttl/config.py index 9a0347d7c..58bdfef60 100644 --- a/mttl/config.py +++ b/mttl/config.py @@ -201,10 +201,7 @@ def _set_defaults(self): self.data_dir = os.getenv("TRAIN_DIR", "/tmp/") self.output_dir = os.getenv("OUTPUT_DIR", "./output") - self.finetune_task_name = None - self.example_to_ids_path = None # path to clustering of data - self.embeddings_path = None # NI related configs self.use_task_descriptions = False # Use task descriptions @@ -215,7 +212,6 @@ def _set_defaults(self): 0 # Use some few-shot examples if possible 
(applies to NI) ) - self.task_prefix = None # xfit has task prefixes detailing # of shots, seed, etc; this is automatically filled in at fine-tuning time self.exp_name = None self.wandb_project = None self.padding_side = "right" From c7fd876a7656fd24ed3ccca59deab395dc4cd46d Mon Sep 17 00:00:00 2001 From: Alessandro Sordoni Date: Mon, 5 Aug 2024 12:10:10 -0700 Subject: [PATCH 15/31] remove length printing --- mttl/models/expert_model.py | 13 ++++++++++++- 1 file changed, 12 insertions(+), 1 deletion(-) diff --git a/mttl/models/expert_model.py b/mttl/models/expert_model.py index 9f48409ed..d22e01fa3 100644 --- a/mttl/models/expert_model.py +++ b/mttl/models/expert_model.py @@ -83,7 +83,6 @@ def forward(self, batch, reduction="mean"): input_ids = batch["input_ids"] labels = batch["labels"] - print(input_ids.shape[-1]) outputs = self.model.forward(input_ids, attention_mask=batch["attention_mask"]) # calculate loss, could also be done inside of the model @@ -153,6 +152,18 @@ def training_step(self, batch, _): f"{self._log_pref}train/total_loss", total_loss, on_step=True, prog_bar=True ) + # get peak and avg memory + peak_memory = torch.cuda.max_memory_allocated() / 1024**3 + memory = torch.cuda.memory_allocated() / 1024**3 + + self.log( + f"{self._log_pref}train/peak_memory", + peak_memory, + on_step=True, + prog_bar=True, + ) + self.log(f"{self._log_pref}train/memory", memory, on_step=True, prog_bar=True) + for i, pg in enumerate(self.optimizers().optimizer.param_groups): self.log(f"train/lr_{i}", pg["lr"]) return total_loss From 19464d70e52af9339b8dd5bee4af8a48a3a673ee Mon Sep 17 00:00:00 2001 From: Alessandro Sordoni Date: Mon, 5 Aug 2024 14:08:23 -0700 Subject: [PATCH 16/31] remove some comments --- mttl/datamodule/base.py | 4 ---- 1 file changed, 4 deletions(-) diff --git a/mttl/datamodule/base.py b/mttl/datamodule/base.py index edd6ef1e3..c9cc2db54 100644 --- a/mttl/datamodule/base.py +++ b/mttl/datamodule/base.py @@ -682,9 +682,6 @@ def pack_sequences(self, dataset, max_sequences=4, shuffle=True): if shuffle: dataset = dataset.shuffle(seed=42) - # TODO: first partition dataset according to `task_name`, and - # pack each task individually to ensure that we don't mix tasks - # Very basic code that will iterate over sequences one by one, # and merge together until the max_input_length is reached # This is not optimal, but it's a start @@ -707,7 +704,6 @@ def append_to_running_seq(container, example): else: raise ValueError(f"Unknown type {type(v)}") - # TODO: THis is SOMEHOW WRONG. CHECK. 
container["seq_lens"] += [len(example["input_ids"])] def add_finished_sequence(container, example): From 7166a208e09a83452900065f58574908ed3bba79 Mon Sep 17 00:00:00 2001 From: Alessandro Sordoni Date: Mon, 5 Aug 2024 14:53:54 -0700 Subject: [PATCH 17/31] mistral for chatbot --- .../configs/chatbot/mistral_7b.json | 29 +++++++++++++++++++ 1 file changed, 29 insertions(+) create mode 100644 projects/modular_llm/configs/chatbot/mistral_7b.json diff --git a/projects/modular_llm/configs/chatbot/mistral_7b.json b/projects/modular_llm/configs/chatbot/mistral_7b.json new file mode 100644 index 000000000..45e46d03d --- /dev/null +++ b/projects/modular_llm/configs/chatbot/mistral_7b.json @@ -0,0 +1,29 @@ +{ + "lora_rank": 4, + "lora_dropout": 0.05, + "weight_decay": 0.0, + "n_skills": 1, + "model_modifier":"lora", + "modify_modules": ".*", + "modify_layers": "q_proj|k_proj|v_proj|o_proj", + "trainable_param_names": ".*lora_[ab].*", + "num_train_epochs": 5, + "learning_rate": 1e-4, + "micro_batch_size": 1, + "train_batch_size": 16, + "load_in_8bit": 0, + "predict_batch_size": 10, + "precision": "bf16", + "model": "mistralai/Mistral-7B-v0.1", + "model_family": "gpt", + "optimizer": "adamw", + "warmup_proportion": 0.06, + "max_input_length": 4096, + "max_output_length": 1024, + "truncation_side": "left", + "pipeline_eval_tasks": "all", + "eval_before_training": "false", + "device_map": "cuda:0", + "pack_sequences": true, + "attn_implementation": "sdpa" +} \ No newline at end of file From 820f6539ce284645ceb75e5e22906ae7c3168001 Mon Sep 17 00:00:00 2001 From: Alessandro Sordoni Date: Mon, 5 Aug 2024 15:27:29 -0700 Subject: [PATCH 18/31] task_name_field to filter on a sepecific column in the dataset --- mttl/datamodule/base.py | 70 +++++++++++++++++-------- mttl/datamodule/chat_data_module.py | 48 ----------------- mttl/datamodule/cluster_data_module.py | 27 ++++++++++ mttl/datamodule/mt_seq_to_seq_module.py | 44 +++++++++++----- requirements-dev.txt | 30 +++++++++++ tests/test_datamodules.py | 31 ++++++++++- 6 files changed, 168 insertions(+), 82 deletions(-) delete mode 100644 mttl/datamodule/chat_data_module.py create mode 100644 mttl/datamodule/cluster_data_module.py create mode 100644 requirements-dev.txt diff --git a/mttl/datamodule/base.py b/mttl/datamodule/base.py index c9cc2db54..354a91666 100644 --- a/mttl/datamodule/base.py +++ b/mttl/datamodule/base.py @@ -19,6 +19,8 @@ @dataclass class DatasetConfig: + """Generic dataclass for dataset and batching configuration.""" + dataset: str = None data_dir: str = None model: str = None @@ -32,21 +34,23 @@ class DatasetConfig: model_family: str = "gpt" train_on_inputs: bool = False add_eos_to_targets: bool = True - finetune_task_name: str = None subsample_train: int = None subsample_dev: int = None subsample_test: int = None - subsample_per_task: bool = False # Changing default to False + subsample_per_task: bool = False subsample: int = -1 pack_sequences: bool = False # True pad_to_multiple_of: int = 8 max_seq_per_pack: int = 4 + finetune_task_name: str = None + task_id_field: str = "task_id" + task_name_field: str = "task_name" + task_source_field: str = "task_source" @dataclass class DefaultCollator: - """Simple collator - + """ Converts a batch of examples into a batch of inputs and labels for a sequence to sequence task. If model_family is "gpt", then the inputs and outputs are constructed for a causal language model, e.g. concatenated in a single string and labels are set to be -100 for all tokens in the input. 
@@ -62,8 +66,15 @@ class DefaultCollator: model_family: str = "seq2seq" for_generation: bool = False train_on_inputs: bool = False - task_to_id: dict = None - add_eos_to_targets: bool = True + task_to_id: dict = None # mapping from task name to task id + add_eos_to_targets: bool = True # add eos token to the end of the target sequence + task_id_field: str = ( + "task_id" # where to read task id information from in the batch + ) + task_name_field: str = ( + "task_name" # where to read task name information from in the batch + ) + task_source_field: str = "task_source" def enforce_eos(self, targets): # simulate the default behaviour of LLamatokenizer, when adding eos token and truncating: the last token must always be eos @@ -353,9 +364,10 @@ def pad_sequence_wrapper(tensor_list, batch_first, padding_value, side="right"): # Otherwise process as expected sources = [b["source"] for b in batch] labels = [b["target"] for b in batch] - task_ids = [b.get("task_id", None) for b in batch] - task_names = [b.get("task_name", None) for b in batch] - task_sources = [b.get("task_source", None) for b in batch] + + task_ids = [b.get(self.task_id_field, None) for b in batch] + task_names = [b.get(self.task_name_field, None) for b in batch] + task_sources = [b.get(self.task_source_field, None) for b in batch] output_batch = ( self.prepare_inputs_for_gpt_family(sources, labels) @@ -372,7 +384,7 @@ def pad_sequence_wrapper(tensor_list, batch_first, padding_value, side="right"): [self.task_to_id[tn] for tn in task_names] ) elif has_task_ids: - output_batch["task_ids"] = torch.LongTensor(task_ids) + output_batch["task_ids"] = torch.LongTensor(list(map(int, task_ids))) if has_task_names and not has_task_sources: task_sources = task_names @@ -528,6 +540,9 @@ def collate_fn(self): train_on_inputs=self.config.train_on_inputs, add_eos_to_targets=self.config.add_eos_to_targets, task_to_id=self.task_to_id, + task_name_field=self.config.task_name_field, + task_id_field=self.config.task_id_field, + task_source_field=self.config.task_source_field, ) def print_infos(self): @@ -589,7 +604,6 @@ def subsample_dataset(self, dataset, n_samples, per_task=False): Raises: AssertionError: If `per_task` is True and the dataset is not an ArrowDataset. 
- """ def get_dst_idxs_sampled(n_samples, total_size): @@ -603,7 +617,7 @@ def get_dst_idxs_sampled(n_samples, total_size): # make this deterministic to always sample the same subset if isinstance(dataset, ArrowDataset): if per_task: - task_names = dataset.unique("task_name") + task_names = dataset.unique(self.config.task_name_field) subsampled_dataset = [] for i, task_name in enumerate(task_names): logger.info( @@ -612,7 +626,9 @@ def get_dst_idxs_sampled(n_samples, total_size): task_idxs = torch.tensor( [ index - for index, value in enumerate(dataset["task_name"]) + for index, value in enumerate( + dataset[self.config.task_name_field] + ) if value == task_name ] ) @@ -620,7 +636,12 @@ def get_dst_idxs_sampled(n_samples, total_size): task_idxs = task_idxs[idxs] task_dataset = dataset.select(task_idxs) subsampled_dataset.append(task_dataset) - assert all([t == task_name for t in task_dataset["task_name"]]) + assert all( + [ + t == task_name + for t in task_dataset[self.config.task_name_field] + ] + ) subsampled_dataset = concatenate_datasets(subsampled_dataset) else: idxs = get_dst_idxs_sampled(n_samples, total_size) @@ -644,7 +665,7 @@ def __init__( self.for_generation = for_generation self.tokenizer = get_tokenizer(config, for_generation=for_generation) self.setup_dataset() - self.post_setup_dataset() + self._post_setup_dataset() def setup(self, stage=None): pass @@ -755,10 +776,11 @@ def dict_get_item(ex, i): ) return dataset - def post_setup_dataset(self): + def _post_setup_dataset(self): + # subsample the splits if needed for split in ["train", "dev", "test"]: - subsample = getattr(self.config, f"subsample_{split}", None) + if subsample and subsample > 0: dataset = getattr(self, f"{split}_dataset") logger.warning( @@ -798,6 +820,9 @@ def collate_fn(self): train_on_inputs=self.config.train_on_inputs, task_to_id=self.task_to_id, add_eos_to_targets=self.config.add_eos_to_targets, + task_name_field=self.config.task_name_field, + task_id_field=self.config.task_id_field, + task_source_field=self.config.task_source_field, ) @@ -822,12 +847,15 @@ def collate_fn(self): task_to_id=self.task_to_id, multisource=True, add_eos_to_targets=self.config.add_eos_to_targets, + task_name_field=self.config.task_name_field, + task_id_field=self.config.task_id_field, + task_source_field=self.config.task_source_field, ) def get_datamodule(args, for_generation=False, dataset_override=None): from mttl.datamodule.arc_data_module import ArcDataConfig, ArcMultiChoiceDataModule - from mttl.datamodule.chat_data_module import ChatDataConfig, ChatDataModule + from mttl.datamodule.cluster_data_module import ClusterDataConfig, ClusterDataModule from mttl.datamodule.codex_data_module import CodexDataConfig, CodexDataModule from mttl.datamodule.hellaswag_data_module import ( HellaswagDataConfig, @@ -929,11 +957,11 @@ def get_datamodule(args, for_generation=False, dataset_override=None): assert not for_generation config = dataset_to_klass_map[dataset][0] dm = dataset_to_klass_map[dataset][1](config) - elif "chat" in dataset or "clusters" in dataset: - config = ChatDataConfig( + elif "clusters" in dataset: + config = ClusterDataConfig( **common_kwargs, ) - dm = ChatDataModule(config, for_generation=for_generation) + dm = ClusterDataModule(config, for_generation=for_generation) elif "flan" in dataset: config = FlanConfig( **common_kwargs, diff --git a/mttl/datamodule/chat_data_module.py b/mttl/datamodule/chat_data_module.py deleted file mode 100644 index bc1e8128c..000000000 --- a/mttl/datamodule/chat_data_module.py +++ 
/dev/null @@ -1,48 +0,0 @@ -import os -from functools import partial - -import numpy - -from mttl.datamodule.base import DatasetConfig, DefaultDataModule -from mttl.datamodule.mt_seq_to_seq_module import ( - FlatMultiTaskConfig, - apply_source_template, -) -from mttl.datamodule.utils import maybe_filter_hf_dataset_by_task -from mttl.logging import logger -from mttl.models.library.expert_library import DatasetLibrary - - -class ChatDataConfig(FlatMultiTaskConfig): - seed: str = 42 - - -class ChatDataModule(DefaultDataModule): - def setup_dataset(self): - n_proc = int(os.environ.get("MTTL_NUM_PROC_DATASETS", 16)) - dataset = DatasetLibrary.pull_dataset(self.config.dataset) - - ( - self._task_names, - self._task_to_id, - train_dataset, - _, - _, - ) = maybe_filter_hf_dataset_by_task( - dataset, - "cluster_id", - self.config.finetune_task_name, - n_proc=n_proc, - ) - - num_examples = len(train_dataset) - num_train = int(0.9 * num_examples) - num_dev = int(0.1 * num_examples) - - train_dataset = train_dataset.shuffle(seed=self.config.seed) - train_dataset = train_dataset.rename_column("task_name", "original_task_name") - train_dataset = train_dataset.rename_column("cluster_id", "task_name") - - self.train_dataset = train_dataset.select(range(num_train)) - self.dev_dataset = train_dataset.select(range(num_train, num_train + num_dev)) - self.test_dataset = self.dev_dataset diff --git a/mttl/datamodule/cluster_data_module.py b/mttl/datamodule/cluster_data_module.py new file mode 100644 index 000000000..b062822c5 --- /dev/null +++ b/mttl/datamodule/cluster_data_module.py @@ -0,0 +1,27 @@ +import os +from dataclasses import dataclass +from functools import partial + +import numpy + +from mttl.datamodule.base import DatasetConfig, DefaultDataModule +from mttl.datamodule.mt_seq_to_seq_module import ( + FlatMultiTaskConfig, + FlatMultiTaskModule, + apply_source_template, +) +from mttl.datamodule.utils import maybe_filter_hf_dataset_by_task +from mttl.logging import logger +from mttl.models.library.expert_library import DatasetLibrary + + +@dataclass +class ClusterDataConfig(FlatMultiTaskConfig): + """Just adapts the FlatMultiTaskConfig to a dataset containing a column with a cluster id.""" + + task_name_field: str = "cluster_id" + task_id_field: str = "cluster_id" + + +class ClusterDataModule(FlatMultiTaskModule): + pass diff --git a/mttl/datamodule/mt_seq_to_seq_module.py b/mttl/datamodule/mt_seq_to_seq_module.py index b1ef32ce0..ea0fe941f 100644 --- a/mttl/datamodule/mt_seq_to_seq_module.py +++ b/mttl/datamodule/mt_seq_to_seq_module.py @@ -32,6 +32,8 @@ def augment_few_shot_task( max_input_length=None, seed=42, modify_task_source=True, + task_source_field="task_source", + task_name_field="task_name", ): if num_samples is None and few_shots is None: raise ValueError("Either num_samples or few_shots must be specified.") @@ -77,11 +79,11 @@ def map_to_few_shot(_, index): return { "source": prompt, "target": dataset[index]["target"], - "task_name": dataset[index]["task_name"], - "task_source": ( - "few_shot_{}".format(dataset[index]["task_source"]) + task_name_field: dataset[index][task_name_field], + task_source_field: ( + "few_shot_{}".format(dataset[index][task_source_field]) if modify_task_source - else dataset[index]["task_source"] + else dataset[index][task_source_field] ), "split": ( dataset[index]["split"] if "split" in dataset.column_names else None @@ -93,21 +95,31 @@ def map_to_few_shot(_, index): def augment_few_shot( - dataset, num_samples, tokenizer=None, max_input_length=None, seed=42 + 
dataset, + num_samples, + tokenizer=None, + max_input_length=None, + seed=42, + task_name_field="task_name", + task_source_field="task_source", ): """Augment the dataset with few-shot examples.""" import tqdm augmented_dataset = [] - for source in tqdm.tqdm(dataset.unique("task_name")): + for source in tqdm.tqdm(dataset.unique(task_name_field)): augmented_dataset.append( Dataset.from_list( augment_few_shot_task( - dataset.filter(lambda x: x["task_name"] == source), - num_samples, - tokenizer, - max_input_length, - seed, + dataset.filter(lambda x: x[task_name_field] == source), + num_samples=num_samples, + few_shots=None, + tokenizer=tokenizer, + max_input_length=max_input_length, + seed=seed, + modify_task_source=True, + task_name_field=task_name_field, + task_source_field=task_source_field, ) ) ) @@ -139,6 +151,9 @@ def setup_dataset(self): self.dataset = DatasetLibrary.pull_dataset_with_retry(self.config.dataset) n_proc = int(os.environ.get("MTTL_NUM_PROC_DATASETS", 16)) + if "train" not in self.dataset.column_names: + raise ValueError("Flat multi-task datasets must have a 'train' split!") + if "split" not in self.dataset.column_names["train"]: logger.warning( "Dataset *should* have a 'split' column, try removing the dataset manually from the cache! Creating a new 'split' column." @@ -160,7 +175,10 @@ def create_split(rng, _): _, _, ) = maybe_filter_hf_dataset_by_task( - self.dataset, "task_name", self.config.finetune_task_name, n_proc=n_proc + self.dataset, + self.config.task_name_field, + self.config.finetune_task_name, + n_proc=n_proc, ) train_dataset = apply_source_template( @@ -173,6 +191,8 @@ def create_split(rng, _): self.config.augment_few_shot, tokenizer=self.tokenizer, max_input_length=self.config.max_input_length, + task_name_field=self.config.task_name_field, + task_source_field=self.config.task_source_field, ) train_dataset_aug = train_dataset_aug.shuffle() train_dataset = train_dataset_aug.select(range(len(train_dataset))) diff --git a/requirements-dev.txt b/requirements-dev.txt new file mode 100644 index 000000000..855de7b72 --- /dev/null +++ b/requirements-dev.txt @@ -0,0 +1,30 @@ +transformers==4.42.0 +torch==2.3.1 +datasets==2.20.0 +pytorch-lightning==2.3.3 +accelerate +deepspeed +huggingface_hub +click +wandb +rouge +tqdm +pandas +sentence-transformers +fsspec[adl] +prettytable +rich +bitsandbytes +matplotlib +openai +ray +nevergrad +vllm +evaluate +seaborn +azure-storage-blob +azure-identity +einops +nltk +mocker +pytest \ No newline at end of file diff --git a/tests/test_datamodules.py b/tests/test_datamodules.py index 5e175a04b..03d0ad7f2 100644 --- a/tests/test_datamodules.py +++ b/tests/test_datamodules.py @@ -203,7 +203,36 @@ def test_truncation_side(tiny_flan_id): assert true[:10] == dec[:10] -def test_auto_module(tiny_flan_id): +def test_task_name_field(tmp_path): + """Tests whether task names are correctly extracted from the dataset.""" + from datasets import Dataset, DatasetDict, load_dataset + + dataset = [ + {"source": "a", "target": "b", "cluster_id": "0"}, + {"source": "c", "target": "d", "cluster_id": "1"}, + {"source": "e", "target": "f", "cluster_id": "0"}, + {"source": "g", "target": "h", "cluster_id": "1"}, + {"source": "g", "target": "h", "cluster_id": "1"}, + ] + dataset = DatasetDict({"train": Dataset.from_list(dataset)}) + dataset.save_to_disk(tmp_path / "mini_dataset") + + dataset_name = "local://" + str(tmp_path / "mini_dataset") + dataset_config = FlatMultiTaskConfig( + model="EleutherAI/gpt-neo-125m", + model_family="gpt", + 
dataset=dataset_name, + task_id_field="cluster_id", + task_name_field="cluster_id", + ) + datamodule = FlatMultiTaskModule(dataset_config) + batch = next(iter(datamodule.train_dataloader())) + assert "task_ids" in batch + assert "task_names" in batch + assert np.all(x in ["0", "1"] for x in batch["task_names"]) + + +def test_auto_modsule(tiny_flan_id): flan = FlanModule( FlanConfig( dataset=tiny_flan_id, From cd7dbc6aecc288ee6763d3b90a4652528ca50053 Mon Sep 17 00:00:00 2001 From: Alessandro Sordoni Date: Mon, 5 Aug 2024 16:29:35 -0700 Subject: [PATCH 19/31] add support for a few flags in the embedding model --- projects/modular_llm/get_clusters.py | 38 ++++++++++++++++++++++------ 1 file changed, 30 insertions(+), 8 deletions(-) diff --git a/projects/modular_llm/get_clusters.py b/projects/modular_llm/get_clusters.py index 15a49a141..4d453ca07 100644 --- a/projects/modular_llm/get_clusters.py +++ b/projects/modular_llm/get_clusters.py @@ -1,6 +1,7 @@ import argparse import os from datetime import datetime +from functools import partial import numpy as np import torch @@ -29,7 +30,7 @@ def get_dataset(args): np.random.shuffle(indices) split = int(np.floor(args.subsample * dataset_size)) subset_indices = indices[:split] - subset_dataset = Subset(dataset, subset_indices) + subset_dataset = dataset.select(subset_indices) train_dataloader = DataLoader( subset_dataset, batch_size=args.batch_size, num_workers=args.num_workers @@ -38,12 +39,13 @@ def get_dataset(args): dataset, batch_size=args.batch_size, num_workers=args.num_workers ) - return train_dataloader, all_dataloader, dataset + return train_dataloader, all_dataloader, dataset, subset_dataset class EncoderModel: - - def __init__(self, encoding, model, device="cuda"): + def __init__( + self, encoding, model, prompt_name=None, batch_size=None, device="cuda" + ): if encoding == "classifier": model = AdapterRankerHelper.get_ranker_instance( ranker_model="classifier", @@ -54,7 +56,12 @@ def __init__(self, encoding, model, device="cuda"): postprocess = lambda x: x.cpu().detach().numpy() elif encoding == "embedding": model = SentenceTransformer(model, device=device, trust_remote_code=True) - model.encoder_fn = model.encode + if prompt_name: + model.encoder_fn = partial( + model.encode, batch_size=batch_size, prompt_name=prompt_name + ) + else: + model.encoder_fn = model.encode postprocess = lambda x: x else: raise ValueError(f"Invalid encoding: {encoding}") @@ -71,6 +78,7 @@ def parse_args(): parser = argparse.ArgumentParser() parser.add_argument("--dataset", type=str, required=True) parser.add_argument("--encoding", type=str, default="embedding") + parser.add_argument("--prompt_name", type=str, default=None) parser.add_argument( "--model", type=str, default="sentence-transformers/sentence-t5-xxl" ) @@ -100,8 +108,14 @@ def main(): seed_everything(args.seed) device = torch.device("cuda" if torch.cuda.is_available() else "cpu") - model = EncoderModel(args.encoding, args.model, device=device) - train_dataloader, all_dataloader, all_dataset = get_dataset(args) + model = EncoderModel( + args.encoding, + args.model, + prompt_name=args.prompt_name, + batch_size=args.batch_size, + device=device, + ) + train_dataloader, all_dataloader, all_dataset, subset_dataset = get_dataset(args) with torch.no_grad(): embedding_list = [ @@ -126,8 +140,16 @@ def add_cluster_id(example): example["cluster_id"] = [str(i) for i in kmeans.predict(embedding)] return example - dataset = all_dataset.map(add_cluster_id, batched=True, batch_size=args.batch_size) + subset_dataset 
= subset_dataset.map( + add_cluster_id, batched=True, batch_size=args.batch_size + ) + breakpoint() + DatasetLibrary.push_dataset( + subset_dataset, + f"local://{args.output_dir}/subset-clusters-{args.num_clusters}-{datetime.now().isoformat()}", + ) + dataset = all_dataset.map(add_cluster_id, batched=True, batch_size=args.batch_size) dataset_name = ( f"local://{args.output_dir}/clusters-" f"{args.num_clusters}-{datetime.now().isoformat()}" From 0510f31f7ad94a9476603f877802c92c4c741fb4 Mon Sep 17 00:00:00 2001 From: Alessandro Sordoni Date: Mon, 5 Aug 2024 16:29:54 -0700 Subject: [PATCH 20/31] remove breakpoint, oops --- projects/modular_llm/get_clusters.py | 1 - 1 file changed, 1 deletion(-) diff --git a/projects/modular_llm/get_clusters.py b/projects/modular_llm/get_clusters.py index 4d453ca07..15b4e80bb 100644 --- a/projects/modular_llm/get_clusters.py +++ b/projects/modular_llm/get_clusters.py @@ -143,7 +143,6 @@ def add_cluster_id(example): subset_dataset = subset_dataset.map( add_cluster_id, batched=True, batch_size=args.batch_size ) - breakpoint() DatasetLibrary.push_dataset( subset_dataset, f"local://{args.output_dir}/subset-clusters-{args.num_clusters}-{datetime.now().isoformat()}", From 7a474310609f82cd24f65ef72c81413d5693e936 Mon Sep 17 00:00:00 2001 From: Alessandro Sordoni Date: Wed, 7 Aug 2024 11:20:44 -0700 Subject: [PATCH 21/31] script to dump orca in json format + em topic script --- mttl/cli/prettify_orca.py | 39 ++++ projects/modular_llm/gpt4_tagging.py | 330 +++++++++++++++++++++++++++ 2 files changed, 369 insertions(+) create mode 100644 mttl/cli/prettify_orca.py create mode 100644 projects/modular_llm/gpt4_tagging.py diff --git a/mttl/cli/prettify_orca.py b/mttl/cli/prettify_orca.py new file mode 100644 index 000000000..97c42bf60 --- /dev/null +++ b/mttl/cli/prettify_orca.py @@ -0,0 +1,39 @@ +import json +import os + +import click +from datasets import load_dataset +from transformers import AutoTokenizer + + +@click.command() +@click.option("--input_jsonl", help="Path to the input jsonl file") +@click.option("--output_jsonl", help="Model name or path to the model") +def main(input_jsonl, output_jsonl): + num_proc = os.environ.get("MTTL_NUM_PROC_DATASETS", 16) + + def prettify(examples): + examples_ = {key: value for key, value in examples.items()} + examples_["messages"] = [] + examples_["metadata"] = [] + examples_["task_name"] = [] + for messages, metadata in zip(examples["messages"], examples["metadata"]): + messages = json.loads(messages) + task_name = json.loads(metadata or "{}").get("task_name", "unknown") + examples_["messages"].append(messages) + examples_["metadata"].append(metadata) + examples_["task_name"].append(task_name) + return examples_ + + dataset = load_dataset("json", data_files=input_jsonl) + dataset = dataset.map( + prettify, + batched=True, # allows to return more examples than the input + remove_columns=dataset["train"].column_names, + num_proc=num_proc, + ) + dataset["train"].to_json(output_jsonl) + + +if __name__ == "__main__": + main() diff --git a/projects/modular_llm/gpt4_tagging.py b/projects/modular_llm/gpt4_tagging.py new file mode 100644 index 000000000..e9577f048 --- /dev/null +++ b/projects/modular_llm/gpt4_tagging.py @@ -0,0 +1,330 @@ +import asyncio +import json +import os +from collections import defaultdict + +import fire +import jinja2 +import numpy as np +import tenacity +from openai import AsyncAzureOpenAI +from tqdm import tqdm as ttqdm +from tqdm.asyncio import tqdm + +client = AsyncAzureOpenAI( + 
api_key=os.environ.get("OPENAI_API_KEY"), + azure_endpoint="https://gcrgpt4aoai7.openai.azure.com/", + api_version="2024-05-01-preview", +) +gpt_model = "gpt-4o-gs" + + +m_template = """ +The following instructions are associated with a specific tag. Your task is to create a precise and descriptive title for this tag, encapsulating a common aspect found in most of these instructions. + +{% for instruction in instructions %} +Instruction: +{{instruction}} + +{% endfor %} + +{% if previous_tag %} +The previous tag for this group of instructions was: +Tag: {{previous_tag}} +{% endif -%} + +Determine a better title that encapsulates a common aspect found in most of these instructions, please provide it in this format: +Tag: Descriptive title for the tag + +""" + + +e_template = """ +Select a tag from the list that most accurately represents the given instruction. If a specific tag accurately describes the instruction, prioritize it over a more generic one. + +{{instruction}} + +Tags: +{% for tag in tags %} +{{loop.index}}. {{tag}} +{% endfor %} + +Please indicate the most appropriate tag by providing ONLY the corresponding number (1-{{tags|length}}). Use the following format: +Tag number: Number + +""" + + +@tenacity.retry( + wait=tenacity.wait_random_exponential(min=10, max=60), + stop=tenacity.stop_after_attempt(100), +) +async def get_completions(prompt, num_completions=1, max_tokens=128): + response = await client.chat.completions.create( + model=gpt_model, + messages=[ + { + "role": "user", + "content": prompt, + }, + ], + stop="\n", + temperature=0.0, + max_tokens=max_tokens, + n=num_completions, + ) + return [choice.message.content for choice in response.choices] + + +async def get_tag(instruction): + return await get_new_tag([instruction], [""]) + + +async def get_new_tag(instructions, previous_tag): + response = await get_completions( + jinja2.Template(m_template).render( + instructions=instructions, previous_tag=previous_tag + ) + ) + response = response[0] + if response.startswith("Tag:"): + return response[len("Tag:") :].strip().rstrip(".").strip() + return None + + +async def assign_tag(instruction, tags): + response = await get_completions( + jinja2.Template(e_template).render(instruction=instruction, tags=tags), + max_tokens=10, + ) + response = response[0] + if response.startswith("Tag number:"): + try: + return int(response[len("Tag number:") :].strip()) - 1 + except: + return None + return None + + +def get_instructions(examples, return_metadata=False): + """Returns user instructions as a stream of messages. + + Optionally returns associated metadata indicating the turn number and whether the message is the last in the example. 
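+
+    Yields plain message strings, or (message, example_id, turn, is_last) tuples
+    when return_metadata is True.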
+    """
+
+    def load_messages(line):
+        # we load the line as a json object
+        example = json.loads(line)
+        messages = example["messages"]
+
+        yield from [
+            (turn, m["content"])
+            for turn, m in enumerate(messages)
+            if m["role"] == "user"
+        ]
+
+    for example_id in range(len(examples)):
+        messages = list(load_messages(examples[example_id]))
+        for i, (turn, message) in enumerate(messages):
+            if return_metadata:
+                # yield the message, its example index, its turn, and whether it is
+                # the last user message of this example
+                is_last = i == len(messages) - 1
+                yield message, example_id, turn, is_last
+            else:
+                yield message
+
+
+def get_batch(dataloader, batch_size):
+    batch = []
+    for _ in range(batch_size):
+        try:
+            batch.append(next(dataloader))
+        except StopIteration:
+            return []
+    return batch
+
+
+async def train_jsonl_file(file_path, output_path, num_tags):
+    """E-M training for tagging instructions."""
+    import random
+
+    random.seed(42)
+
+    tags = None
+    batch_size = num_tags * 25
+    init_examples_per_tag = 5
+
+    with open(file_path, "r") as ifile:
+        train_examples = ifile.readlines()
+        random.shuffle(train_examples)
+
+    train_loader = iter(get_instructions(train_examples))
+
+    # initialize the tags here, init_examples_per_tag examples per tag
+    if tags is None:
+        init_examples = [
+            get_batch(train_loader, init_examples_per_tag) for _ in range(num_tags)
+        ]
+        tags = await tqdm.gather(*[get_tag(c) for c in init_examples])
+
+    for i, tag in enumerate(tags):
+        print(f"{i}.", tag)
+
+    iteration = 0
+    ofile = open(output_path, "w")
+    while iteration < 30:
+        batch = get_batch(train_loader, batch_size)
+        if not batch:
+            break
+
+        # now get tags for the batch
+        tagged_examples = await tqdm.gather(
+            *[assign_tag(example, tags) for example in batch]
+        )
+
+        # now group examples by tags; note that some examples may not have tags
+        # at this point, so some tags might not be useful!
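+        # (this grouping is the E-step of the EM loop: each instruction keeps the index
+        # of its assigned tag, and instructions the model failed to tag are set aside;
+        # the M-step below regenerates a descriptive title for each group)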
+ notag_examples = [] + grouped_examples = defaultdict(list) + for example, tag in zip(batch, tagged_examples): + if tag is None: + notag_examples.append(example) + continue + grouped_examples[int(tag)].append(example) + + # get "widow" tags, create a new tag for each group + for i in range(num_tags): + if i not in grouped_examples or len(grouped_examples[i]) == 1: + grouped_examples[i].extend( + get_batch(train_loader, 10 - len(grouped_examples[i])) + ) + + # now get the m-step for each group + groups, keys = [], [] + for key, group in grouped_examples.items(): + if len(group) >= 10: + group = group[:10] + + groups.append(group) + keys.append(key) + + print("M-step for # tags =", len(groups)) + + new_tags = await tqdm.gather( + *[get_new_tag(group, tags[key]) for key, group in zip(keys, groups)] + ) + new_tags_dict = dict(zip(keys, new_tags)) + + new_tags = [] + for i in range(num_tags): + if i in new_tags_dict: + line = "{:<100} {:>15}".format( + new_tags_dict[i], f"{len(grouped_examples[i])}/{batch_size} (*)" + ) + print(f"{i}.", line) + new_tags.append(new_tags_dict[i].strip()) + else: + print(f"{i}.", "{:<100}".format(tags[i].strip())) + new_tags.append(tags[i].strip()) + + tags = new_tags + iteration += 1 + + ofile.write(json.dumps({"iteration": iteration, "tags": tags}) + "\n") + ofile.flush() + ofile.close() + + +async def infer_jsonl_file(file_path, tags_file, output_path): + tags = None + end = False + batch_size = 100 + + # load the tags + with open(tags_file, "r") as ifile: + tags = json.loads(ifile.readlines()[-1])["tags"] + + for i, t in enumerate(tags): + print(f"{i}.", t.strip()) + + # load all the examples to be tagged + with open(file_path, "r") as ifile: + train_examples = ifile.readlines() + train_loader = list(get_instructions(train_examples, return_metadata=True)) + + ofile = open(output_path, "w") + progress_bar = ttqdm(total=len(train_loader)) + while not end: + batch = [] + metadata = [] + try: + for message, ex_id, turn, is_last in train_loader: + batch.append(message) + metadata.append((ex_id, turn, is_last)) + + if len(batch) == batch_size: + break + except StopIteration: + end = True + + # now get tags for the batch + tagged_examples = await tqdm.gather( + *[assign_tag(example, tags) for example in batch] + ) + + n_random = 0 + for i, tag in enumerate(tagged_examples): + if tag is None: + tag = np.random.randint(len(tags)) + n_random += 1 + + # now we have to inject the tags back in the original dataset + example_id, turn, is_last = metadata[i] + example = json.loads(train_examples[example_id]) + messages = example["messages"] + # tag this particular message with its corresponding cluster + messages[turn]["tag"] = tags[int(tag)] + train_examples[example_id] = json.dumps(example) + # if it is the last message for this example, write it to the file + if is_last: + ofile.write(train_examples[metadata[i][0]] + "\n") + + progress_bar.update(len(batch)) + print("Random tags:", n_random) + ofile.close() + + +async def train_(json_file_path, num_tags=100): + await train_jsonl_file( + json_file_path, + json_file_path.replace(".jsonl", "") + f"_{gpt_model}_tags-{num_tags}.jsonl", + num_tags=num_tags, + ) + + +async def infer_(json_file_path, tags_file): + await infer_jsonl_file( + json_file_path, + tags_file, + tags_file.replace(".jsonl", "") + f"_inferred.jsonl", + ) + + +class GPT4EMTagging: + def infer(self, tags_path, file_path, model="gpt-4o-gs"): + global gpt_model + + gpt_model = model + print("Working on...", file_path) + asyncio.run(infer_(file_path, tags_path)) 
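+        # fire.Fire(GPT4EMTagging) at the bottom of this file exposes this method as a
+        # CLI subcommand, e.g. (file names are purely illustrative):
+        #   python gpt4_tagging.py infer --tags_path=tags.jsonl --file_path=chats.jsonl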
+ + def train(self, file_path, num_tags=100, model="gpt-4o-gs"): + global gpt_model + + gpt_model = model + print("Working on...", file_path) + asyncio.run(train_(file_path, num_tags=num_tags)) + + +if __name__ == "__main__": + fire.Fire(GPT4EMTagging) From b7a7ed988b0178bf7755a011e5b359274ab74aec Mon Sep 17 00:00:00 2001 From: Alessandro Sordoni Date: Wed, 7 Aug 2024 12:42:17 -0700 Subject: [PATCH 22/31] some speed-ups --- projects/modular_llm/gpt4_tagging.py | 23 ++++++++++++++--------- 1 file changed, 14 insertions(+), 9 deletions(-) diff --git a/projects/modular_llm/gpt4_tagging.py b/projects/modular_llm/gpt4_tagging.py index e9577f048..54790168b 100644 --- a/projects/modular_llm/gpt4_tagging.py +++ b/projects/modular_llm/gpt4_tagging.py @@ -52,7 +52,7 @@ Please indicate the most appropriate tag by providing ONLY the corresponding number (1-{{tags|length}}). Use the following format: Tag number: Number -""" +Tag number:""" @tenacity.retry( @@ -95,14 +95,19 @@ async def get_new_tag(instructions, previous_tag): async def assign_tag(instruction, tags): response = await get_completions( jinja2.Template(e_template).render(instruction=instruction, tags=tags), - max_tokens=10, + max_tokens=5, ) response = response[0] - if response.startswith("Tag number:"): - try: - return int(response[len("Tag number:") :].strip()) - 1 - except: - return None + try: + # parse response right away + return int(response.strip()) - 1 + except: + # if the model has generated "tag number"... + if response.startswith("Tag number:"): + try: + return int(response[len("Tag number:") :].strip()) - 1 + except: + return None return None @@ -250,10 +255,10 @@ async def infer_jsonl_file(file_path, tags_file, output_path): # load all the examples to be tagged with open(file_path, "r") as ifile: train_examples = ifile.readlines() - train_loader = list(get_instructions(train_examples, return_metadata=True)) + train_loader = get_instructions(train_examples, return_metadata=True) ofile = open(output_path, "w") - progress_bar = ttqdm(total=len(train_loader)) + progress_bar = ttqdm(total=len(train_examples) + 100_000) while not end: batch = [] metadata = [] From 6b29acb4f9b469f83e2347056b89ab26a3281e84 Mon Sep 17 00:00:00 2001 From: Alessandro Sordoni Date: Wed, 7 Aug 2024 12:44:23 -0700 Subject: [PATCH 23/31] remove previous instruction --- projects/modular_llm/gpt4_tagging.py | 17 ++++------------- 1 file changed, 4 insertions(+), 13 deletions(-) diff --git a/projects/modular_llm/gpt4_tagging.py b/projects/modular_llm/gpt4_tagging.py index 54790168b..727b987df 100644 --- a/projects/modular_llm/gpt4_tagging.py +++ b/projects/modular_llm/gpt4_tagging.py @@ -28,11 +28,6 @@ {% endfor %} -{% if previous_tag %} -The previous tag for this group of instructions was: -Tag: {{previous_tag}} -{% endif -%} - Determine a better title that encapsulates a common aspect found in most of these instructions, please provide it in this format: Tag: Descriptive title for the tag @@ -77,14 +72,12 @@ async def get_completions(prompt, num_completions=1, max_tokens=128): async def get_tag(instruction): - return await get_new_tag([instruction], [""]) + return await get_new_tag([instruction]) -async def get_new_tag(instructions, previous_tag): +async def get_new_tag(instructions): response = await get_completions( - jinja2.Template(m_template).render( - instructions=instructions, previous_tag=previous_tag - ) + jinja2.Template(m_template).render(instructions=instructions) ) response = response[0] if response.startswith("Tag:"): @@ -215,9 +208,7 @@ async 
def train_jsonl_file(file_path, output_path, num_tags): print("M-step for # tags =", len(groups)) - new_tags = await tqdm.gather( - *[get_new_tag(group, tags[key]) for key, group in zip(keys, groups)] - ) + new_tags = await tqdm.gather(*[get_new_tag(group) for group in groups]) new_tags_dict = dict(zip(keys, new_tags)) new_tags = [] From 7c70ac4c18fe44f86d709cb3aef9e01f0b4220c2 Mon Sep 17 00:00:00 2001 From: Alessandro Sordoni Date: Thu, 8 Aug 2024 06:16:07 -0700 Subject: [PATCH 24/31] resume inference from output file --- projects/modular_llm/gpt4_tagging.py | 16 ++++++++++++++-- 1 file changed, 14 insertions(+), 2 deletions(-) diff --git a/projects/modular_llm/gpt4_tagging.py b/projects/modular_llm/gpt4_tagging.py index 727b987df..5019a6099 100644 --- a/projects/modular_llm/gpt4_tagging.py +++ b/projects/modular_llm/gpt4_tagging.py @@ -248,6 +248,14 @@ async def infer_jsonl_file(file_path, tags_file, output_path): train_examples = ifile.readlines() train_loader = get_instructions(train_examples, return_metadata=True) + try: + with open(output_path, "r") as ofile: + resume_from_line = ofile.readlines() + resume_from = len(resume_from_line) + print("Resuming from example id:", resume_from) + except: + resume_from = 0 + ofile = open(output_path, "w") progress_bar = ttqdm(total=len(train_examples) + 100_000) while not end: @@ -255,8 +263,12 @@ async def infer_jsonl_file(file_path, tags_file, output_path): metadata = [] try: for message, ex_id, turn, is_last in train_loader: - batch.append(message) - metadata.append((ex_id, turn, is_last)) + if ex_id >= resume_from: + batch.append(message) + metadata.append((ex_id, turn, is_last)) + else: + if is_last: + progress_bar.update(1) if len(batch) == batch_size: break From 003b556bfb1f7e6a6a2c374d34a344b14cd6fd1e Mon Sep 17 00:00:00 2001 From: Alessandro Sordoni Date: Fri, 9 Aug 2024 05:47:17 -0700 Subject: [PATCH 25/31] contrastive tagging --- projects/modular_llm/gpt4_tagging.py | 106 +++++++++++++++++++++++---- 1 file changed, 92 insertions(+), 14 deletions(-) diff --git a/projects/modular_llm/gpt4_tagging.py b/projects/modular_llm/gpt4_tagging.py index 5019a6099..b59b52d7f 100644 --- a/projects/modular_llm/gpt4_tagging.py +++ b/projects/modular_llm/gpt4_tagging.py @@ -34,6 +34,28 @@ """ +m_contrastive_template = """ +The following two groups of instructions A and B, and each is associated with a specific tag. Your task is to create a precise and descriptive title for the tag of group B, encapsulating a common aspect found in most of these instructions. +The title should describe the instructions in group B in a way that contrasts with common aspects of instructions of the group A. + +{% for instruction in instructions_a %} +Instruction (A): +{{instruction}} + +{% endfor %} + +{% for instruction in instructions_b %} +Instruction (B): +{{instruction}} + +{% endfor %} + +Determine a better title for the instructions in group B that encapsulates a common aspect found in most of these instructions while being different from the instructions of group A, please provide it in this format: +Tag: Descriptive title for the tag + +""" + + e_template = """ Select a tag from the list that most accurately represents the given instruction. If a specific tag accurately describes the instruction, prioritize it over a more generic one. 
@@ -85,7 +107,23 @@ async def get_new_tag(instructions): return None +async def get_new_tag_contrastive(instructions_a, instructions_b): + response = await get_completions( + jinja2.Template(m_contrastive_template).render( + instructions_a=instructions_a, instructions_b=instructions_b + ) + ) + response = response[0] + if response.startswith("Tag:"): + return response[len("Tag:") :].strip().rstrip(".").strip() + return None + + async def assign_tag(instruction, tags): + if len(instruction.split()) >= 10_000: + print("too long instruction, skipping.") + return None + response = await get_completions( jinja2.Template(e_template).render(instruction=instruction, tags=tags), max_tokens=5, @@ -142,14 +180,14 @@ def get_batch(dataloader, batch_size): return batch -async def train_jsonl_file(file_path, output_path, num_tags): +async def train_jsonl_file(file_path, output_path, num_tags, mode="normal"): """E-M training for tagging instructions.""" import random random.seed(42) tags = None - batch_size = num_tags * 25 + batch_size = num_tags * 15 init_examples_per_tag = 5 with open(file_path, "r") as ifile: @@ -190,7 +228,16 @@ async def train_jsonl_file(file_path, output_path, num_tags): continue grouped_examples[int(tag)].append(example) - # get "widow" tags, create a new tag for each group + # print "entropy" of the distribution, normalize the length of each group and compute the normalized entropy + lengths = [len(v) for v in grouped_examples.values()] + norm_lengths = np.array(lengths) / np.sum(lengths) + entropy = -np.sum(norm_lengths * np.log(norm_lengths + 1e-6)) / np.log( + len(lengths) + ) + print("Entropy of tags assignments: ", entropy) + + # get "widow" tags, create a new tag for each group, here we sample random examples + # until each group has 10 examples for i in range(num_tags): if i not in grouped_examples or len(grouped_examples[i]) == 1: grouped_examples[i].extend( @@ -200,6 +247,7 @@ async def train_jsonl_file(file_path, output_path, num_tags): # now get the m-step for each group groups, keys = [], [] for key, group in grouped_examples.items(): + # cap maximum group size if len(group) >= 10: group = group[:10] @@ -208,7 +256,26 @@ async def train_jsonl_file(file_path, output_path, num_tags): print("M-step for # tags =", len(groups)) - new_tags = await tqdm.gather(*[get_new_tag(group) for group in groups]) + # in normal mode we just sample conditioned on the current group + if mode == "normal": + new_tags = await tqdm.gather(*[get_new_tag(group) for group in groups]) + # in contrastive mode, we sample a negative group the tag should *not* describe + elif mode == "contrastive": + print("Contrastive mode") + args = [] + for i in range(len(groups)): + j = np.random.choice([k for k in range(len(groups)) if k != i]) + args.append((groups[j], groups[i])) + + new_tags = await tqdm.gather( + *[ + get_new_tag_contrastive(group_a, group_b) + for group_a, group_b in args + ] + ) + else: + raise ValueError("Invalid mode:", mode) + new_tags_dict = dict(zip(keys, new_tags)) new_tags = [] @@ -231,7 +298,7 @@ async def train_jsonl_file(file_path, output_path, num_tags): ofile.close() -async def infer_jsonl_file(file_path, tags_file, output_path): +async def infer_jsonl_file(file_path, tags_file, output_path, num_inferences=-1): tags = None end = False batch_size = 100 @@ -256,9 +323,10 @@ async def infer_jsonl_file(file_path, tags_file, output_path): except: resume_from = 0 + done = 0 ofile = open(output_path, "w") progress_bar = ttqdm(total=len(train_examples) + 100_000) - while not end: + 
while not end and (num_inferences == -1 or done < num_inferences): batch = [] metadata = [] try: @@ -296,42 +364,52 @@ async def infer_jsonl_file(file_path, tags_file, output_path): # if it is the last message for this example, write it to the file if is_last: ofile.write(train_examples[metadata[i][0]] + "\n") + done += 1 progress_bar.update(len(batch)) print("Random tags:", n_random) ofile.close() -async def train_(json_file_path, num_tags=100): +async def train_(json_file_path, output_path, num_tags=100, mode="normal"): await train_jsonl_file( json_file_path, - json_file_path.replace(".jsonl", "") + f"_{gpt_model}_tags-{num_tags}.jsonl", + output_path, num_tags=num_tags, + mode=mode, + ) + await infer_( + json_file_path, + output_path, + num_inferences=100_000, ) -async def infer_(json_file_path, tags_file): +async def infer_(json_file_path, tags_file, num_inferences=-1): await infer_jsonl_file( json_file_path, tags_file, - tags_file.replace(".jsonl", "") + f"_inferred.jsonl", + tags_file.replace(".jsonl", "") + f"_inferred_{num_inferences}.jsonl", + num_inferences=num_inferences, ) class GPT4EMTagging: - def infer(self, tags_path, file_path, model="gpt-4o-gs"): + def infer(self, tags_path, file_path, model="gpt-4o-gs", num_inferences=-1): global gpt_model gpt_model = model print("Working on...", file_path) - asyncio.run(infer_(file_path, tags_path)) + asyncio.run(infer_(file_path, tags_path, num_inferences=num_inferences)) - def train(self, file_path, num_tags=100, model="gpt-4o-gs"): + def train( + self, file_path, output_path, num_tags=100, model="gpt-4o-gs", mode="normal" + ): global gpt_model gpt_model = model print("Working on...", file_path) - asyncio.run(train_(file_path, num_tags=num_tags)) + asyncio.run(train_(file_path, output_path, num_tags=num_tags, mode=mode)) if __name__ == "__main__": From c55fe170aa60401c4b9912a0206d0a91c6376def Mon Sep 17 00:00:00 2001 From: Alessandro Sordoni Date: Fri, 9 Aug 2024 06:12:35 -0700 Subject: [PATCH 26/31] minimal template changes --- projects/modular_llm/gpt4_tagging.py | 6 +++--- 1 file changed, 3 insertions(+), 3 deletions(-) diff --git a/projects/modular_llm/gpt4_tagging.py b/projects/modular_llm/gpt4_tagging.py index b59b52d7f..aabdecc14 100644 --- a/projects/modular_llm/gpt4_tagging.py +++ b/projects/modular_llm/gpt4_tagging.py @@ -35,8 +35,8 @@ m_contrastive_template = """ -The following two groups of instructions A and B, and each is associated with a specific tag. Your task is to create a precise and descriptive title for the tag of group B, encapsulating a common aspect found in most of these instructions. -The title should describe the instructions in group B in a way that contrasts with common aspects of instructions of the group A. +The following two groups of instructions A and B are each associated with a specific tag. Your task is to create a precise and descriptive title for the tag of group B, encapsulating a common aspect found in most of these instructions. +The title should describe the instructions in group B in a way that contrasts with common aspects of the instructions in group A. 
{% for instruction in instructions_a %} Instruction (A): @@ -50,7 +50,7 @@ {% endfor %} -Determine a better title for the instructions in group B that encapsulates a common aspect found in most of these instructions while being different from the instructions of group A, please provide it in this format: +Determine a better title for the instructions in group B that encapsulates a common aspect found in most of these instructions while not encapsulating common aspects of the instructions in group A, please provide it in this format: Tag: Descriptive title for the tag """ From 6df090a2b94ca2a08380c6d30f073620c211ed87 Mon Sep 17 00:00:00 2001 From: Alessandro Sordoni Date: Fri, 9 Aug 2024 09:55:04 -0700 Subject: [PATCH 27/31] gpt tagging --- projects/modular_llm/gpt4_tagging.py | 33 ++++++++++++++++++++++------ 1 file changed, 26 insertions(+), 7 deletions(-) diff --git a/projects/modular_llm/gpt4_tagging.py b/projects/modular_llm/gpt4_tagging.py index aabdecc14..0393bce88 100644 --- a/projects/modular_llm/gpt4_tagging.py +++ b/projects/modular_llm/gpt4_tagging.py @@ -16,7 +16,10 @@ azure_endpoint="https://gcrgpt4aoai7.openai.azure.com/", api_version="2024-05-01-preview", ) -gpt_model = "gpt-4o-gs" + + +gpt_model = None +g_infer_model = "gpt-4o-mini" m_template = """ @@ -39,13 +42,13 @@ The title should describe the instructions in group B in a way that contrasts with common aspects of the instructions in group A. {% for instruction in instructions_a %} -Instruction (A): +Instruction (Group A): {{instruction}} {% endfor %} {% for instruction in instructions_b %} -Instruction (B): +Instruction (Group B): {{instruction}} {% endfor %} @@ -120,7 +123,7 @@ async def get_new_tag_contrastive(instructions_a, instructions_b): async def assign_tag(instruction, tags): - if len(instruction.split()) >= 10_000: + if len(instruction.split()) >= 50_000: print("too long instruction, skipping.") return None @@ -372,16 +375,22 @@ async def infer_jsonl_file(file_path, tags_file, output_path, num_inferences=-1) async def train_(json_file_path, output_path, num_tags=100, mode="normal"): + global gpt_model + global g_infer_model + await train_jsonl_file( json_file_path, output_path, num_tags=num_tags, mode=mode, ) + + # switch inference mode + gpt_model = g_infer_model await infer_( json_file_path, output_path, - num_inferences=100_000, + num_inferences=10_000, ) @@ -395,19 +404,29 @@ async def infer_(json_file_path, tags_file, num_inferences=-1): class GPT4EMTagging: - def infer(self, tags_path, file_path, model="gpt-4o-gs", num_inferences=-1): + def infer(self, tags_path, file_path, model="gpt-4o-mini", num_inferences=-1): global gpt_model gpt_model = model + print("Working on...", file_path) asyncio.run(infer_(file_path, tags_path, num_inferences=num_inferences)) def train( - self, file_path, output_path, num_tags=100, model="gpt-4o-gs", mode="normal" + self, + file_path, + output_path, + num_tags=100, + model="gpt-4o-gs", + infer_model="gpt-4o-mini", + mode="normal", ): global gpt_model + global g_infer_model gpt_model = model + g_infer_model = infer_model + print("Working on...", file_path) asyncio.run(train_(file_path, output_path, num_tags=num_tags, mode=mode)) From 09a2befafd9d0a3f47393ed13929e94f437881d3 Mon Sep 17 00:00:00 2001 From: Alessandro Sordoni Date: Fri, 9 Aug 2024 10:19:29 -0700 Subject: [PATCH 28/31] infer/train model --- projects/modular_llm/gpt4_tagging.py | 3 ++- 1 file changed, 2 insertions(+), 1 deletion(-) diff --git a/projects/modular_llm/gpt4_tagging.py 
b/projects/modular_llm/gpt4_tagging.py index 0393bce88..588179ea8 100644 --- a/projects/modular_llm/gpt4_tagging.py +++ b/projects/modular_llm/gpt4_tagging.py @@ -398,7 +398,8 @@ async def infer_(json_file_path, tags_file, num_inferences=-1): await infer_jsonl_file( json_file_path, tags_file, - tags_file.replace(".jsonl", "") + f"_inferred_{num_inferences}.jsonl", + tags_file.replace(".jsonl", "") + + f"_inferred_{gpt_model}_n{num_inferences}.jsonl", num_inferences=num_inferences, ) From d803dae0f7e27dade3919d4e8b28ceb8cdce7d1f Mon Sep 17 00:00:00 2001 From: Alessandro Sordoni Date: Wed, 14 Aug 2024 14:50:30 -0700 Subject: [PATCH 29/31] move stuff into modular-chatbot --- .../chatbot/mistral_7b.json | 0 .../get_clusters.py | 0 .../gpt4_tagging.py | 0 .../jsonl_to_hf_chat_dataset.py | 46 +++- projects/modular_chatbot/train_experts.py | 197 ++++++++++++++++++ 5 files changed, 236 insertions(+), 7 deletions(-) rename projects/{modular_llm/configs => modular_chatbot}/chatbot/mistral_7b.json (100%) rename projects/{modular_llm => modular_chatbot}/get_clusters.py (100%) rename projects/{modular_llm => modular_chatbot}/gpt4_tagging.py (100%) rename {mttl/cli => projects/modular_chatbot}/jsonl_to_hf_chat_dataset.py (61%) create mode 100644 projects/modular_chatbot/train_experts.py diff --git a/projects/modular_llm/configs/chatbot/mistral_7b.json b/projects/modular_chatbot/chatbot/mistral_7b.json similarity index 100% rename from projects/modular_llm/configs/chatbot/mistral_7b.json rename to projects/modular_chatbot/chatbot/mistral_7b.json diff --git a/projects/modular_llm/get_clusters.py b/projects/modular_chatbot/get_clusters.py similarity index 100% rename from projects/modular_llm/get_clusters.py rename to projects/modular_chatbot/get_clusters.py diff --git a/projects/modular_llm/gpt4_tagging.py b/projects/modular_chatbot/gpt4_tagging.py similarity index 100% rename from projects/modular_llm/gpt4_tagging.py rename to projects/modular_chatbot/gpt4_tagging.py diff --git a/mttl/cli/jsonl_to_hf_chat_dataset.py b/projects/modular_chatbot/jsonl_to_hf_chat_dataset.py similarity index 61% rename from mttl/cli/jsonl_to_hf_chat_dataset.py rename to projects/modular_chatbot/jsonl_to_hf_chat_dataset.py index 88036c1ed..7ff298514 100644 --- a/mttl/cli/jsonl_to_hf_chat_dataset.py +++ b/projects/modular_chatbot/jsonl_to_hf_chat_dataset.py @@ -6,6 +6,21 @@ from transformers import AutoTokenizer +def normalize_tag(tag): + tag = tag.lower() + tag = tag.replace(" ", "_") + return tag + + +def custom_transform(messages): + text = "" + for idx, message in enumerate(messages): + assert message["role"] in ["system", "user", "assistant"] + + text += f"<|im_start|>{message['role']}\n{message['content']}\n<|im_end|>\n" + return text + + @click.command() @click.option("--input_jsonl", help="Path to the input jsonl file") @click.option("--model", help="Model name or path to the model") @@ -14,13 +29,14 @@ help="Path to the output hf dataset. 
Same as input file but with no extension if not provided.", default=None, ) -def main(input_jsonl, model, output_dataset): +def main(input_jsonl, model=None, output_dataset=None): if output_dataset is None: output_dataset, _ = os.path.splitext(input_jsonl) num_proc = os.environ.get("MTTL_NUM_PROC_DATASETS", 16) - tokenizer = AutoTokenizer.from_pretrained(model) + if model is not None: + tokenizer = AutoTokenizer.from_pretrained(model) def apply_chat_template(example): return tokenizer.apply_chat_template( @@ -36,26 +52,42 @@ def chat_progression(examples): targets = [] task_names = [] num_rounds = [] + last_tags = [] + last_tag = "[EMPTY]" + for messages, metadata in zip(examples["messages"], examples["metadata"]): - messages = json.loads(messages) task_name = json.loads(metadata or "{}").get("task_name", "unknown") - chat_progression = [] rounds = 1 + chat_progression = [] for message in messages: if message["role"] != "assistant": chat_progression.append(message) + # if this message has been tagged, then update the last tag and consider that tag for the next assistant message + if message["role"] == "user" and "tag" in message: + last_tag = normalize_tag(message["tag"]) + else: + last_tag = "[EMPTY]" else: - sources.append(apply_chat_template(list(chat_progression))) - targets.append(apply_chat_template([dict(message)])) + if len(chat_progression) == 1: + continue + + func = ( + apply_chat_template if model is not None else custom_transform + ) + sources.append(func(list(chat_progression))) + targets.append(func([message])) task_names.append(task_name) num_rounds.append(rounds) chat_progression.append(message) + last_tags.append(last_tag) rounds += 1 + return { "source": sources, "target": targets, "task_name": task_names, "round": num_rounds, + "tag": last_tags, } dataset = load_dataset("json", data_files=input_jsonl) @@ -64,7 +96,7 @@ def chat_progression(examples): chat_progression, batched=True, # allows to return more examples than the input remove_columns=dataset["train"].column_names, - num_proc=num_proc, + num_proc=1, ) dataset.save_to_disk(output_dataset) diff --git a/projects/modular_chatbot/train_experts.py b/projects/modular_chatbot/train_experts.py new file mode 100644 index 000000000..43098cf27 --- /dev/null +++ b/projects/modular_chatbot/train_experts.py @@ -0,0 +1,197 @@ +import os +import shutil +import sys +from tempfile import TemporaryDirectory +from typing import Type + +import torch +from pytorch_lightning import Trainer, seed_everything + +from mttl.callbacks import ( + DownstreamEvalCallback, + LiveCheckpointCallback, + NanoMMLUCallback, + RougeCallback, +) +from mttl.config import Args, ExpertConfig +from mttl.datamodule.base import get_datamodule +from mttl.logging import get_pl_loggers, logger, setup_logging +from mttl.models.expert_model import ExpertModel, MoEModel +from mttl.models.library.expert import Expert, load_expert +from mttl.models.library.expert_library import ExpertLibrary, LocalExpertLibrary +from mttl.models.monitors import get_monitors +from mttl.utils import generate_random_string, rank_zero_only_and_wait, remote_login + + +def train_experts(args: Args, model_class: Type[ExpertModel]): + seed_everything(args.seed, workers=True) + + # get directory of the current file + setup_logging(args.output_dir) + + logger.info("Args: {}".format(args.to_json())) + + remote_login(args.remote_token) + expert_library = None + if args.library_id: + + @rank_zero_only_and_wait(before=False, after=True) + def create_library(args): + expert_library = 
ExpertLibrary.get_expert_library( + repo_id=args.library_id, + create=True, + destination_id=args.destination_library_id, + ) + return expert_library + + expert_library = create_library(args) + + loggers = get_pl_loggers(args) + + dm = get_datamodule(args) + args.n_tasks = len(dm._task_names) + args.task_names = dm._task_names + + module = model_class(**args.asdict()) + + # get metric monitors for models + callbacks = get_monitors(args) + if "mbpp" in args.dataset: + monitor = "downstream/mbpp" + mode = "max" + else: + monitor = "val/loss" + mode = "min" + + checkpoint_callback = LiveCheckpointCallback( + dirpath=args.output_dir, + monitor=monitor, + save_last=True, + mode=mode, + save_each_epoch=args.save_each_epoch, + ) + callbacks.append(checkpoint_callback) + + if args.eval_rouge_flag: + rouge = RougeCallback( + get_datamodule(args, for_generation=True), + every_n_epochs=3 if args.num_train_epochs > 5 else 1, + ) + callbacks.append(rouge) + else: + logger.warning( + "Deactivating rouge callback as it is not enabled in the config. Please set `eval_rouge_flag=True`." + ) + + if args.eval_mmlu_flag: + mmlu = NanoMMLUCallback( + get_datamodule(args, dataset_override="mmlu", for_generation=True), + every_n_epochs=3 if args.num_train_epochs > 3 else 1, + ) + callbacks.append(mmlu) + else: + logger.warning( + "Deactivating mmlu callback as it is not enabled in the config. Please set `eval_mmlu_flag=True`." + ) + + if args.pipeline_eval_tasks: + if args.pipeline_eval_tasks == "all": + args.pipeline_eval_tasks = "arc-challenge,arc-easy,boolq,hellaswag,humaneval,mbpp,openbookqa,piqa,bbh-fast,winogrande" + + eval = DownstreamEvalCallback(args) + callbacks.append(eval) + else: + logger.warning( + "Deactivating downstream eval callback as it is not enabled in the config. Please set `pipeline_eval_tasks`." + ) + + val_check_interval = args.eval_every + if val_check_interval == -1 or val_check_interval is None: + val_check_interval = None + elif not (0.0 < val_check_interval < 1.0): + val_check_interval = args.gradient_accumulation_steps * args.eval_every + if val_check_interval > len(dm.train_dataloader()): + val_check_interval = len(dm.train_dataloader()) + elif val_check_interval > args.total_steps and args.total_steps != -1: + val_check_interval = args.total_steps + + trainer = Trainer( + devices=-1, + accelerator="gpu", + logger=loggers, + num_sanity_val_steps=0, + default_root_dir=args.output_dir, + max_epochs=args.num_train_epochs, + max_steps=args.total_steps + 1 if args.total_steps != -1 else -1, + gradient_clip_val=args.max_grad_norm, + strategy=args.compute_strategy if args.compute_strategy else "auto", + callbacks=callbacks, + enable_checkpointing=False, + log_every_n_steps=args.gradient_accumulation_steps, + accumulate_grad_batches=args.gradient_accumulation_steps, + precision=( + int(args.precision) if args.precision in ["16", "32"] else args.precision + ), + val_check_interval=val_check_interval, + ) + + # initial validation only for a bunch of datasets... ? + if args.eval_before_training: + # validating before training fails with deepspeed + trainer.validate(module, dm) + + if args.do_train: + trainer.fit(module, dm) + + torch.cuda.empty_cache() + + # reload best model before pushing! 
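+    # (falls back to the last checkpoint when no best checkpoint was saved; with the
+    # deepspeed strategy the ZeRO shards are first consolidated into an fp32 state dict
+    # before loading and testing)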
+ checkpoint = ( + checkpoint_callback.best_model_path or checkpoint_callback.last_model_path + ) + if args.compute_strategy == "deepspeed": + from deepspeed.utils.zero_to_fp32 import ( + convert_zero_checkpoint_to_fp32_state_dict, + ) + + new_path = checkpoint.replace(".ckpt", "_fp32.ckpt") + + @rank_zero_only_and_wait(before=True, after=True) + def convert_ckpt(path, new_path): + convert_zero_checkpoint_to_fp32_state_dict(path, new_path) + + convert_ckpt(checkpoint, new_path) + checkpoint = torch.load(new_path) + else: + checkpoint = torch.load(checkpoint)["state_dict"] + + module.load_state_dict(checkpoint) + trainer.test(module, dm) + + @rank_zero_only_and_wait(before=False, after=True) + def upload_library(expert_library, module): + if expert_library is not None: + # refresh expert library: so we dont overwrite the readme if the remote has changed. + expert_library.refresh_from_remote() + + if isinstance(module, MoEModel): + with expert_library.batched_commit(): + for expert_name in module.experts_names: + expert = module.get_expert_instance(expert_name) + expert_library.add_expert(expert, expert_name) + elif isinstance(module, ExpertModel): + expert = module.as_expert() + expert_name = ( + args.expert_name + or args.finetune_task_name + or generate_random_string() + ) + expert_library.add_expert(expert, expert_name) + else: + raise ValueError("Model class not recognized") + + upload_library(expert_library, module) + + +if __name__ == "__main__": + train_experts(ExpertConfig.parse(), ExpertModel) From d68af91483908e6eecd91cb71de4095221797cf5 Mon Sep 17 00:00:00 2001 From: Alessandro Sordoni Date: Fri, 16 Aug 2024 08:10:08 -0700 Subject: [PATCH 30/31] small updates --- mttl/datamodule/base.py | 12 ++++++-- mttl/datamodule/cluster_data_module.py | 28 ------------------- mttl/registrable.py | 9 ++++-- .../{chatbot => configs}/mistral_7b.json | 0 4 files changed, 17 insertions(+), 32 deletions(-) delete mode 100644 mttl/datamodule/cluster_data_module.py rename projects/modular_chatbot/{chatbot => configs}/mistral_7b.json (100%) diff --git a/mttl/datamodule/base.py b/mttl/datamodule/base.py index cf8173289..615c48021 100644 --- a/mttl/datamodule/base.py +++ b/mttl/datamodule/base.py @@ -909,8 +909,8 @@ def collate_fn(self): def get_datamodule(args, for_generation=False, dataset_override=None): + from mttl.config import DataArgs from mttl.datamodule.arc_data_module import ArcDataConfig, ArcMultiChoiceDataModule - from mttl.datamodule.cluster_data_module import ClusterDataConfig, ClusterDataModule from mttl.datamodule.codex_data_module import CodexDataConfig, CodexDataModule from mttl.datamodule.hellaswag_data_module import ( HellaswagDataConfig, @@ -940,7 +940,15 @@ def get_datamodule(args, for_generation=False, dataset_override=None): WinograndeMultiChoiceDataModule, ) - # refactor all the common arguments below into a dict common kwargs + # if we have a DataArgs object, we can directly create the datamodule + if isinstance(args, DataArgs): + dataset_config = args.dataset_config + + return DataModule.get_class_by_config_class(type(dataset_config))( + dataset_config, for_generation=for_generation + ) + + # we fall back to previous behavior dataset = args.dataset if not dataset_override else dataset_override common_kwargs = { diff --git a/mttl/datamodule/cluster_data_module.py b/mttl/datamodule/cluster_data_module.py deleted file mode 100644 index 4ab4f896e..000000000 --- a/mttl/datamodule/cluster_data_module.py +++ /dev/null @@ -1,28 +0,0 @@ -import os -from dataclasses import dataclass 
-from functools import partial - -import numpy - -from mttl.datamodule.base import DataModule, DatasetConfig -from mttl.datamodule.mt_seq_to_seq_module import ( - FlatMultiTaskConfig, - FlatMultiTaskModule, - apply_source_template, -) -from mttl.datamodule.utils import maybe_filter_hf_dataset_by_task -from mttl.logging import logger -from mttl.models.library.expert_library import DatasetLibrary - - -@dataclass -class ClusterDataConfig(FlatMultiTaskConfig): - """Just adapts the FlatMultiTaskConfig to a dataset containing a column with a cluster id.""" - - task_name_field: str = "cluster_id" - task_id_field: str = "cluster_id" - - -@DataModule.register("cluster_flat_multitask", ClusterDataConfig) -class ClusterDataModule(FlatMultiTaskModule): - pass diff --git a/mttl/registrable.py b/mttl/registrable.py index b5b4c5cd1..95d690b19 100644 --- a/mttl/registrable.py +++ b/mttl/registrable.py @@ -23,14 +23,19 @@ def add_to_registry(subclass): return add_to_registry + @classmethod + def from_config(cls, config: Any, **kwargs) -> "Registrable": + klass = cls.get_class_by_config_class(type(config)) + return klass(config, **kwargs) + @classmethod def get_config_class_by_name(cls, name: str) -> Type: - subclass, config_cls = Registrable._registry[cls].get(name) + subclass, config_cls = Registrable._registry[cls][name] return config_cls @classmethod def get_class_by_name(cls, name: str) -> Type: - subclass, config_cls = Registrable._registry[cls].get(name) + subclass, config_cls = Registrable._registry[cls][name] return subclass @classmethod diff --git a/projects/modular_chatbot/chatbot/mistral_7b.json b/projects/modular_chatbot/configs/mistral_7b.json similarity index 100% rename from projects/modular_chatbot/chatbot/mistral_7b.json rename to projects/modular_chatbot/configs/mistral_7b.json From 8c239bd2fd8d9b0c24b7c558fd269992816f1ca6 Mon Sep 17 00:00:00 2001 From: Alessandro Sordoni Date: Fri, 16 Aug 2024 08:10:44 -0700 Subject: [PATCH 31/31] load from dataset_type --- mttl/datamodule/base.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/mttl/datamodule/base.py b/mttl/datamodule/base.py index 615c48021..b7ec98a22 100644 --- a/mttl/datamodule/base.py +++ b/mttl/datamodule/base.py @@ -941,7 +941,7 @@ def get_datamodule(args, for_generation=False, dataset_override=None): ) # if we have a DataArgs object, we can directly create the datamodule - if isinstance(args, DataArgs): + if isinstance(args, DataArgs) and args.dataset_type is not None: dataset_config = args.dataset_config return DataModule.get_class_by_config_class(type(dataset_config))(