Merge pull request #85 from microsoft/move-files
CLEAN: Move all the functions from projects/modular_llm into mttl
sordonia authored Aug 14, 2024
2 parents 6b83bb9 + 3c19476 commit 5d9c009
Showing 38 changed files with 452 additions and 2,536 deletions.
26 changes: 25 additions & 1 deletion mttl/__init__.py
@@ -1 +1,25 @@
#
from mttl.models.packed_attention_monkey_patch import (
    flash_attn_func_wrapper,
    flash_attn_varlen_func_wrapper,
    scaled_dot_product_attention,
)

try:
    import flash_attn
    from flash_attn import flash_attn_func, flash_attn_varlen_func

    flash_attn._default_flash_attn_func = flash_attn_func
    flash_attn._default_flash_attn_varlen_func = flash_attn_varlen_func
    flash_attn.flash_attn_varlen_func = flash_attn_varlen_func_wrapper
    flash_attn.flash_attn_func = flash_attn_func_wrapper
except ImportError:
    from mttl.logging import logger

    logger.info("Flash Attention not available")

import torch

torch.nn.functional._default_scaled_dot_product_attention = (
    torch.nn.functional.scaled_dot_product_attention
)
torch.nn.functional.scaled_dot_product_attention = scaled_dot_product_attention
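
Since the patch above stores the original implementations on private attributes before replacing them, they can be restored later; a minimal sketch of such a restore helper (not part of this commit, the helper name is hypothetical):

import torch.nn.functional as F


def restore_default_sdpa():
    # Point back to the implementation saved by mttl at import time.
    F.scaled_dot_product_attention = F._default_scaled_dot_product_attention
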
54 changes: 54 additions & 0 deletions mttl/callbacks.py
@@ -16,6 +16,7 @@

from mttl.datamodule.base import DataModule
from mttl.evaluators import MMLUEvaluator
from mttl.evaluators.base import EvaluatorRunner, setup_evaluators
from mttl.evaluators.evaluators import Evaluator
from mttl.logging import logger
from mttl.models.utils import transfer_batch_to_device
@@ -673,3 +674,56 @@ def get_loss(self, model, **kwargs):
    @property
    def tokenizer(self):
        return self.datamodule.tokenizer


class DownstreamEvalCallback(cb.Callback):
    METRIC_KEY = "downstream"

    def __init__(self, args) -> None:
        super().__init__()

        self.args = args
        self.runner: EvaluatorRunner = setup_evaluators(
            model_type=args.model,
            model_family=args.model_family,
            max_input_length=args.max_input_length,
            max_output_length=args.max_output_length,
            predict_batch_size=args.predict_batch_size,
            truncation_side=args.truncation_side,
            tasks=args.pipeline_eval_tasks,
            output_path=os.path.join(args.output_dir, self.METRIC_KEY),
            add_eos_to_targets=args.add_eos_to_downstream_targets,
        )

    def on_validation_epoch_start(
        self, trainer: Trainer, pl_module: pl.LightningModule
    ) -> None:
        if trainer.global_step == 0 and not self.args.eval_before_training:
            return

        if self.args.eval_every_n_epoch is None or (
            self.args.eval_every_n_epoch
            and trainer.current_epoch % self.args.eval_every_n_epoch != 0
        ):
            return

        metrics = self.runner.run(pl_module)
        for task, metric in metrics.items():
            pl_module.log(
                f"{self.METRIC_KEY}/{task}",
                metric,
                on_epoch=True,
                prog_bar=True,
            )

    def on_test_epoch_end(
        self, trainer: Trainer, pl_module: pl.LightningModule
    ) -> None:
        metrics = self.runner.run(pl_module)
        for task, metric in metrics.items():
            pl_module.log(
                f"{self.METRIC_KEY}_last/{task}",
                metric,
                on_epoch=True,
                prog_bar=True,
            )
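
A hypothetical wiring sketch for the new callback; `args` is assumed to be a TrainingArgs-style object exposing the fields read in __init__, and `module`/`datamodule` stand in for an existing LightningModule and DataModule:

import pytorch_lightning as pl

from mttl.callbacks import DownstreamEvalCallback

# `args`, `module` and `datamodule` are assumed to exist; only the wiring is shown.
trainer = pl.Trainer(max_epochs=3, callbacks=[DownstreamEvalCallback(args)])
trainer.fit(module, datamodule=datamodule)   # logs downstream/<task> each eval epoch
trainer.test(module, datamodule=datamodule)  # logs downstream_last/<task> at the end
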
File renamed without changes.
14 changes: 12 additions & 2 deletions mttl/config.py
@@ -373,7 +373,6 @@ class TrainingArgs(DataArgs):
    seed: int = 42
    debug: bool = False

    eval_before_training: bool = True
    precision: str = "32"
    monitor_grad_alignment_on: str = None

@@ -401,6 +400,7 @@ class TrainingArgs(DataArgs):
    eval_mmlu_flag: bool = False  # eval mmlu performance during training
    eval_rouge_flag: bool = False  # eval rouge during training
    eval_before_training: bool = True
    create_transfer_matrix: bool = False
    pipeline_eval_tasks: str = None
    save_if_loaded_from_ckpt: bool = True
    dataset_type: str = None
@@ -418,6 +418,16 @@ def dataset_config(self):
        )

    def __post_init__(self):
        if self.model is not None and self.model_family is None:
            # attempt to infer the model family from the model name
            if "t5" in self.model or "T0" in self.model:
                self.model_family = "seq2seq"
            else:
                self.model_family = "gpt"
            logger.warning(
                "Model family not specified, assuming {}".format(self.model_family)
            )

        if self.attn_implementation == "eager" and self.pack_sequences:
            logger.warning(
                "Eager attention is not compatible with packed sequences"
@@ -518,7 +528,6 @@ class EvaluationConfig(MultiExpertConfig, TransformArgs):
    merge_or_route: str = None  # "uniform", "ties", "clown"
    tasksets_path: str = None
    remove_experts: str = None
    create_transfer_matrix: bool = False
    es_metric: str = "loss"
    n_ng_iterations: int = 30  # number of iterations for LoraHub
    recompute_prototypes: bool = False
@@ -529,6 +538,7 @@ class MoEExpertConfig(MultiExpertConfig):
    moe_ent_reg: float = 0.0
    moe_ent_free_bits: float = 0.0
    moe_num_experts: int = 8
    init_from_scratch: bool = True


@dataclass
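The __post_init__ change in this file infers model_family from the model name when it is left unset; a small sketch of the intended behaviour, assuming TrainingArgs can be instantiated directly with keyword arguments:

from mttl.config import TrainingArgs

# Names containing "t5" or "T0" are treated as encoder-decoder models...
args = TrainingArgs(model="google/t5-small-lm-adapt")
assert args.model_family == "seq2seq"

# ...everything else falls back to "gpt", with a warning.
args = TrainingArgs(model="microsoft/phi-2")
assert args.model_family == "gpt"
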
89 changes: 89 additions & 0 deletions mttl/dataloader/flan_utils.py
@@ -1,3 +1,11 @@
from functools import partial

import numpy as np
from datasets import Dataset, concatenate_datasets

from mttl.models.library.expert_library import DatasetLibrary


def encode_with_messages_format(example):
    message_text = ""
    intruction = "<|user|>\n" + example["user"].strip() + "\n"
@@ -13,3 +21,84 @@ def encode_with_messages_format(example):
        "full_text": message_text,
        "hash_text": message_text,
    }


def gen_from_iterable_dataset(iterable_ds):
    yield from iterable_ds


def download_flan(split="train", download_size=-1, cutoff=10_000, verbose=True):
    dataset_name = "chiayewken/flan-v2"

    if download_size <= 0:
        dataset = DatasetLibrary.pull_dataset(dataset_name, split=split)
    else:
        iter_ds = DatasetLibrary.pull_dataset(
            dataset_name, split=split, streaming=True
        ).take(download_size)
        dataset = Dataset.from_generator(
            partial(gen_from_iterable_dataset, iter_ds), features=iter_ds.features
        )

    # group the dataset using the task_name
    task_names = dataset.unique("task_name")
    print("Num Tasks: ", len(task_names))

    all_datasets = []
    for task_name in task_names:
        print("Processing task: ", task_name)

        task_dataset = dataset.filter(
            lambda x: x["task_name"] == task_name, num_proc=24
        )

        # if the dataset is too large, we randomly sample "cutoff" examples for training
        task_dataset = task_dataset.shuffle(42)

        if cutoff > 0 and len(task_dataset) > cutoff:
            task_dataset = task_dataset.select(range(cutoff))

        def assign_split(example, idx):
            rng = np.random.RandomState(idx)
            draw = rng.rand()
            if draw < 0.8:
                return {"split": "train"}
            elif draw < 0.9:
                return {"split": "validation"}
            else:
                return {"split": "test"}

        task_dataset = task_dataset.map(assign_split, with_indices=True)
        # randomly cut the dataset again
        task_dataset = task_dataset.shuffle(42)

        if cutoff and len(task_dataset) > cutoff:
            task_dataset = task_dataset.select(range(cutoff))

        all_datasets.append(task_dataset)

        print("Dumping task", task_name)
        if verbose:
            print("# Train", len(task_dataset.filter(lambda x: x["split"] == "train")))
            print("# Test", len(task_dataset.filter(lambda x: x["split"] == "test")))
            print(
                "# Val", len(task_dataset.filter(lambda x: x["split"] == "validation"))
            )

    concatenated_datasets = concatenate_datasets(all_datasets)

    def clean_task(x):
        if "task_name" not in x:
            return x

        x["task_name"] = (
            x["task_name"]
            .replace(":", "_")
            .replace("/", "_")
            .replace("-", "_")
            .replace(".", "_")
        )
        return x

    concatenated_datasets = concatenated_datasets.map(lambda x: clean_task(x))
    return concatenated_datasets
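
A usage sketch for download_flan as defined above; with download_size set, only a streamed slice of chiayewken/flan-v2 is pulled, and the split column comes from the per-example assign_split mapping:

from mttl.dataloader.flan_utils import download_flan

# Pull a 1,000-example streamed slice and keep at most 100 examples per task.
flan = download_flan(split="train", download_size=1_000, cutoff=100, verbose=False)

print(sorted(flan.unique("split")))   # expected: ['test', 'train', 'validation']
print(len(flan.unique("task_name")))  # number of distinct (sanitized) task names
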
@@ -88,7 +88,7 @@ def maybe_filter_hf_dataset_by_key(dataset, key, task_names: str = None, n_proc=
    return task_names, task_to_id, train_dataset, dev_dataset, test_dataset


@dataclass
@DataModule.register("clip_experts", CLIPExpertsConfig)
class CLIPExpertsDatamodule(DataModule):
    # The dataset format is [x, E, accuracy]
    DATA_ENV = "CLIP_DATA_DIR"
@@ -158,7 +158,7 @@ def collate_fn(self):
        )


@dataclass
@DataModule.register("clip_triple", CLIPExpertsConfig)
class CLIPTripleDataModule(DataModule):
    # the dataset format is [task_eval, input x, positive_experts, negative_experts]
    DATA_ENV = "CLIP_TRIPLE_DATA_DIR"
4 changes: 4 additions & 0 deletions mttl/datamodule/mt_seq_to_seq_module.py
@@ -217,8 +217,12 @@ def filter_task_source(include_task_source, example):
@DataModule.register("flan", config_cls=FlanConfig)
class FlanModule(DataModule):
    def setup_dataset(self):
        if self.config.dataset is None:
            raise ValueError("Please specify a flan dataset to load.")

        dataset = DatasetLibrary.pull_dataset_with_retry(self.config.dataset)
        n_proc = int(os.environ.get("MTTL_NUM_PROC_DATASETS", 16))

        if "split" not in dataset.column_names["train"]:
            raise ValueError(
                "Dataset must have a 'split' column, try removing the dataset manually from the cache."
17 changes: 5 additions & 12 deletions mttl/datamodule/utils.py
@@ -111,16 +111,9 @@ def get_tokenizer_with_args(
    truncation_side="right",
    for_generation=False,
):
    if model_family is None:
        raise ValueError("model_family is None, please fix your config!")

    if "llama-2" in model_name:
        tokenizer = LlamaTokenizer.from_pretrained(model_name)
        tokenizer.pad_token_id = 0
        if not model_family == "gpt":
            raise ValueError(
                "We detected a Llama model, but model_family != 'gpt', fix your config!"
            )
    else:
        if "phi-2" == model_name:
            # local phi-2 version. use `microsoft/phi-2 for the official hf version`
@@ -150,11 +143,11 @@ def get_tokenizer_with_args(
    # do not add eos token, we will add it accordingly *if* needed.
    tokenizer.add_eos_token = False

    if tokenizer.pad_token_id is None:
        logger.warning(
            "!!! Setting pad_token_id to eos_token_id, given that pad_token_id was not detected !!!"
        )
        tokenizer.pad_token_id = tokenizer.eos_token_id
    if tokenizer.pad_token_id is None:
        logger.warning(
            "!!! Setting pad_token_id to eos_token_id, given that pad_token_id was not detected !!!"
        )
        tokenizer.pad_token_id = tokenizer.eos_token_id

    tokenizer.mttl_merges_space = tokenizer_merges_space(tokenizer)
    tokenizer.mttl_enforces_eos = tokenizer_enforces_eos(tokenizer)
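
The pad-token fallback above follows a common pattern for GPT-style tokenizers that ship without a pad token; a small illustration outside mttl, using a tokenizer known to lack one:

from transformers import AutoTokenizer

tok = AutoTokenizer.from_pretrained("gpt2")  # gpt2 defines eos_token but no pad_token
assert tok.pad_token_id is None

tok.pad_token_id = tok.eos_token_id  # same fallback as in get_tokenizer_with_args
assert tok.pad_token_id == 50256
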
@@ -8,20 +8,12 @@
from mttl.models.expert_model import MultiExpertModel
from mttl.models.library.expert import Expert
from mttl.models.library.expert_library import VirtualLocalLibrary
from mttl.models.library.library_transforms import LibraryTransform
from mttl.models.library.library_transforms import (
    LibraryTransform,
    LibraryTransformConfig,
)
from mttl.models.library.utils import get_svd_embedding

RETRIEVERS = {}


def register_retriever(name):
    def decorator(cls):
        if name in RETRIEVERS:
            raise ValueError(f"Retriever {name} already registered")
        RETRIEVERS[name] = cls
        return cls

    return decorator
from mttl.registrable import Registrable


class Retriever(LibraryTransform):
@@ -53,7 +45,7 @@ def transform(self, **kwargs):
        raise NotImplementedError()


@register_retriever("random")
@LibraryTransform.register("random", LibraryTransformConfig)
class RandomRetriever(Retriever):
    def transform(
        self, expert_lib, current_task, task_expert: Expert = None, **kwargs
@@ -106,7 +98,7 @@ def get_lora_task_embeddings(module: MultiExpertModel):
    return embeddings


@register_retriever("lora_sim")
@LibraryTransform.register("lora_sim", LibraryTransformConfig)
class LoraSimRetriever(Retriever):
    def transform(
        self,
@@ -163,7 +155,7 @@ def transform(
        return resulting_library


@register_retriever("svdemb")
@LibraryTransform.register("svdemb", LibraryTransformConfig)
class SVDEmbeddingRetriever(Retriever):
    def transform(
        self,
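With register_retriever gone, retrievers now register through LibraryTransform.register, as the decorators above show. A sketch of what adding another retriever in this module could look like; the name, config reuse, and placeholder body are hypothetical:

@LibraryTransform.register("identity", LibraryTransformConfig)
class IdentityRetriever(Retriever):
    """Hypothetical example: returns the expert library unchanged."""

    def transform(self, expert_lib, current_task, **kwargs):
        # A real retriever would return a filtered copy of expert_lib here.
        return expert_lib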

0 comments on commit 5d9c009
