Skip to content

Commit

Permalink
Code restructuring 1 (#66)
Browse files Browse the repository at this point in the history
* removed duplicated get_datamodule

* Clean utils before merging them

* Merged proj.utils into mttl.utils, moved logger to mttl.logging

* move warn_once to mttl.logging

* Moved TableLogger to mttl.logging

* Moved init_wandb_logger to mttl.logger

* Moved pl loggers to mttl.logging

* Moved get_task_expert to mttl.models.library.expert_library

* "fix" circular dependency

* Moved proj.evaluators to mttl.evaluators

* removed unused get_loss function

* moved get_svd_embedding to mttl.models.library.utils
  • Loading branch information
matheper authored Jul 23, 2024
1 parent 68a910f commit 0c851e0
Show file tree
Hide file tree
Showing 60 changed files with 378 additions and 873 deletions.
98 changes: 96 additions & 2 deletions mttl/callbacks.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,18 +2,23 @@
import os
import shutil
import sys
from abc import ABC, abstractmethod

import pytorch_lightning as pl
import torch
import tqdm
import wandb
from pytorch_lightning import LightningModule, Trainer
from pytorch_lightning import callbacks as cb
from pytorch_lightning.callbacks.progress.tqdm_progress import Tqdm
from pytorch_lightning.utilities.rank_zero import rank_zero_only
from torch.optim import Optimizer

from mttl.datamodule.base import DefaultDataModule
from mttl.evaluators import MMLUEvaluator
from mttl.evaluators.evaluators import Evaluator
from mttl.logging import logger
from mttl.models.utils import transfer_batch_to_device
from mttl.utils import logger

DEBUG = False

Expand Down Expand Up @@ -298,7 +303,7 @@ def __init__(
):
super().__init__()

from mttl.evaluators.mmlu_evaluator import MMLUEvaluator
from mttl.evaluators import MMLUEvaluator

self.evaluator = MMLUEvaluator(datamodule=datamodule)
self.every_n_epochs = every_n_epochs
Expand Down Expand Up @@ -579,3 +584,92 @@ def init_test_tqdm(self) -> Tqdm:
file=sys.stderr,
)
return bar


class EvalCallback(ABC):
    """Minimal interface for callbacks that score a model during training."""

    @abstractmethod
    def evaluate_model(self, model, prefix=""):
        """Evaluate ``model`` and return a scalar score.

        ``prefix`` is prepended to any metric names the implementation logs.
        """


class MMLUEvalCallback(MMLUEvaluator, EvalCallback):
    """Callback that evaluates a model on the MMLU test split.

    Builds an ``MMLUDataConfig`` from the subset of attributes on ``config``
    that the dataclass actually declares, then delegates scoring to the
    inherited ``MMLUEvaluator``.
    """

    def __init__(
        self,
        config,
        name="mmlu_test_callback",
        split="test",
        subsample=-1,
        use_vllm=False,
    ):
        from mttl.datamodule.mmlu_data_module import MMLUDataConfig

        # Validate with an explicit raise, not `assert`: asserts are stripped
        # under `python -O`, silently disabling the check.
        if split != "test":
            raise ValueError(
                f"MMLUEvalCallback only supports split='test', got {split!r}"
            )
        self.split = split
        self.use_vllm = use_vllm
        # Keep only the fields MMLUDataConfig declares; `config` may carry
        # many unrelated attributes.
        mmlu_config = MMLUDataConfig(
            **{
                k: v
                for k, v in config.__dict__.items()
                if k in MMLUDataConfig.__dataclass_fields__
            }
        )
        super().__init__(mmlu_config, use_vllm=use_vllm)
        self.subsample = subsample
        self.name = name

    def evaluate_model(self, model, prefix=""):
        """Score ``model`` on MMLU and return the mean over all tasks.

        Moves the model to CUDA when available and, if a wandb run is active,
        logs the score under ``{prefix}{name}_{split}``.
        """
        device = "cuda" if torch.cuda.is_available() else "cpu"
        model.to(device)
        score = self.evaluate(model, self.split, self.subsample)["all"]["mean"]
        if wandb.run is not None:
            wandb.log({f"{prefix}{self.name}_{self.split}": score})
        return score


class TestLossEvaluator(LossCallback, Evaluator):
    """Evaluator that reports loss on a chosen datamodule split.

    ``evaluate`` returns the *negated* loss so the result can be maximized by
    checkpoint-selection logic; ``get_loss`` returns the raw loss.
    """

    def __init__(
        self,
        datamodule: DefaultDataModule,
        name="test",
        split="test",
        subsample=-1,
        **kwargs,
    ):
        self.split = split
        if split == "test":
            dataloader = datamodule.test_dataloader(subsample=subsample)
        elif split in ("val", "valid", "validation"):
            dataloader = datamodule.val_dataloader(subsample=subsample)
        elif split == "train":
            dataloader = datamodule.train_dataloader(subsample=subsample)
        else:
            # Previously an unknown split fell through and raised an opaque
            # UnboundLocalError on `dataloader`; fail loudly instead.
            raise ValueError(
                f"Unknown split {split!r}; expected 'test', 'val', "
                f"'valid', 'validation' or 'train'"
            )
        super().__init__(
            dataloader=dataloader,
            name=name,
            output_dir=None,
            eval_every_opt_step=0,
            checkpoint_oracle=False,
        )
        self.datamodule = datamodule

    @property
    def tasks(self):
        # Task names come straight from the wrapped datamodule.
        return self.datamodule.task_names

    def evaluate(self, model):
        """Return negated mean loss (higher is better) keyed per task group."""
        device = "cuda" if torch.cuda.is_available() else "cpu"
        model.to(device)
        loss = self.test(model)
        # Negate so callers that maximize scores can use this directly.
        loss *= -1.0
        return {"all": {"mean": loss.item()}, f"{self.name}": {"mean": loss.item()}}

    def get_loss(self, model, **kwargs):
        """Return the raw (non-negated) mean loss as a Python float."""
        device = "cuda" if torch.cuda.is_available() else "cpu"
        model.to(device)
        loss = self.test(model, **kwargs)
        return loss.item()

    @property
    def tokenizer(self):
        # Expose the datamodule's tokenizer for generation-based consumers.
        return self.datamodule.tokenizer
1 change: 0 additions & 1 deletion mttl/cli/convert_library_to_hf_phi2.py
Original file line number Diff line number Diff line change
Expand Up @@ -8,7 +8,6 @@
sys.path.append(os.path.join(os.path.dirname(__file__), "..", ".."))

from mttl.models.library.expert_library import ExpertLibrary
from mttl.utils import logger


def translate_lib_to_hf_phi(
Expand Down
2 changes: 1 addition & 1 deletion mttl/config.py
Original file line number Diff line number Diff line change
Expand Up @@ -5,7 +5,7 @@
from string import Template
from typing import Dict

from mttl.utils import logger, setup_logging
from mttl.logging import logger, setup_logging


class Config:
Expand Down
3 changes: 1 addition & 2 deletions mttl/dataloader/mmlu_dataset.py
Original file line number Diff line number Diff line change
@@ -1,12 +1,11 @@
"""MMLU Dataset."""

import copy
import os

import datasets
import pandas as pd

from mttl.utils import logger
from mttl.logging import logger

_CITATION = """\
@article{hendryckstest2021,
Expand Down
2 changes: 1 addition & 1 deletion mttl/dataloader/ni_dataset.py
Original file line number Diff line number Diff line change
Expand Up @@ -23,7 +23,7 @@

import datasets

from mttl.utils import logger
from mttl.logging import logger

_CITATION = """
@article{wang2022benchmarking,
Expand Down
3 changes: 0 additions & 3 deletions mttl/dataloader/ni_metrics.py
Original file line number Diff line number Diff line change
@@ -1,15 +1,12 @@
import argparse
import json
import logging
import os
import string

import numpy as np
from torchmetrics.text.rouge import ROUGEScore
from transformers import AutoTokenizer

from mttl.utils import logger


class GPTTokenizer:
def __init__(self):
Expand Down
3 changes: 2 additions & 1 deletion mttl/dataloader/oasst1_readers.py
Original file line number Diff line number Diff line change
@@ -1,7 +1,8 @@
import torch

from mttl.logging import logger
from mttl.models.library.expert_library import DatasetLibrary
from mttl.utils import hash_example, logger
from mttl.utils import hash_example


class Oasst1Template:
Expand Down
4 changes: 1 addition & 3 deletions mttl/datamodule/base.py
Original file line number Diff line number Diff line change
Expand Up @@ -11,7 +11,7 @@
from transformers.tokenization_utils_base import PaddingStrategy

from mttl.datamodule.utils import get_tokenizer
from mttl.utils import logger
from mttl.logging import logger


@dataclass
Expand Down Expand Up @@ -442,8 +442,6 @@ def collate_fn(self):
)

def print_infos(self):
from mttl.utils import logger

logger.info("Dataset name: %s", self.config.dataset)
logger.info("Reader class: %s", self.__class__.__name__)
if self.train_dataset is not None and len(self.train_dataset) > 0:
Expand Down
5 changes: 1 addition & 4 deletions mttl/datamodule/facts_lm_module.py
Original file line number Diff line number Diff line change
Expand Up @@ -6,8 +6,8 @@

from mttl.datamodule.base import DefaultDataModule
from mttl.datamodule.platypus_module import PlatypusConfig
from mttl.logging import logger, setup_logging
from mttl.models.library.expert_library import DatasetLibrary
from mttl.utils import logger


@dataclass
Expand Down Expand Up @@ -92,10 +92,7 @@ def form_dataset(data):


if __name__ == "__main__":
import os

from mttl.config import Config
from mttl.utils import setup_logging

setup_logging()

Expand Down
3 changes: 2 additions & 1 deletion mttl/datamodule/mt_seq_to_seq_module.py
Original file line number Diff line number Diff line change
Expand Up @@ -6,7 +6,8 @@
from datasets import Dataset, concatenate_datasets

from mttl.datamodule.base import DatasetConfig, DefaultDataModule
from mttl.datamodule.utils import logger, maybe_filter_hf_dataset_by_task
from mttl.datamodule.utils import maybe_filter_hf_dataset_by_task
from mttl.logging import logger
from mttl.models.library.expert_library import DatasetLibrary


Expand Down
2 changes: 1 addition & 1 deletion mttl/datamodule/ni_data_module.py
Original file line number Diff line number Diff line change
Expand Up @@ -14,8 +14,8 @@

from mttl.datamodule.base import DatasetConfig, DefaultCollator, DefaultDataModule
from mttl.datamodule.utils import maybe_filter_hf_dataset_by_task
from mttl.logging import logger
from mttl.models.library.expert_library import DatasetLibrary
from mttl.utils import logger


@dataclass
Expand Down
5 changes: 1 addition & 4 deletions mttl/datamodule/retrieval_lm_module.py
Original file line number Diff line number Diff line change
Expand Up @@ -10,8 +10,8 @@
from transformers.tokenization_utils_base import PaddingStrategy

from mttl.datamodule.utils import get_tokenizer
from mttl.logging import logger, setup_logging
from mttl.models.library.expert_library import DatasetLibrary
from mttl.utils import logger


@dataclass
Expand Down Expand Up @@ -211,10 +211,7 @@ def setup_dataset(self, stage=None):


if __name__ == "__main__":
import os

from mttl.config import Config
from mttl.utils import setup_logging

setup_logging()

Expand Down
2 changes: 1 addition & 1 deletion mttl/datamodule/utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,7 +2,7 @@

from transformers import AutoTokenizer, LlamaTokenizer

from mttl.utils import logger
from mttl.logging import logger


def maybe_filter_hf_dataset_by_task(
Expand Down
4 changes: 1 addition & 3 deletions mttl/evaluators/base.py
Original file line number Diff line number Diff line change
Expand Up @@ -10,8 +10,8 @@
import torch
from transformers import StoppingCriteria, StoppingCriteriaList

from mttl.logging import logger
from mttl.models.utils import EfficientCheckpointModule, transfer_batch_to_device
from mttl.utils import logger


def decode(preds, tokenizer, clean_up_tokenization_spaces=True):
Expand Down Expand Up @@ -378,8 +378,6 @@ def run(self, module, verbose=False):

import prettytable

from mttl.utils import logger

if self.output_path:
os.makedirs(self.output_path, exist_ok=True)

Expand Down
3 changes: 1 addition & 2 deletions mttl/evaluators/code_evaluator.py
Original file line number Diff line number Diff line change
@@ -1,11 +1,10 @@
import json
import os

import tqdm
from evaluate import load

from mttl.evaluators.base import GenerativeEvaluator, switch_to_eval_mode
from mttl.utils import logger
from mttl.logging import logger


# reference: https://github.com/declare-lab/instruct-eval/blob/main/human_eval/main.py#L35
Expand Down
2 changes: 1 addition & 1 deletion mttl/evaluators/em_evaluator.py
Original file line number Diff line number Diff line change
Expand Up @@ -7,7 +7,7 @@
GenerativeEvaluator,
switch_to_eval_mode,
)
from mttl.utils import logger
from mttl.logging import logger


class EMEvaluator(GenerativeEvaluator):
Expand Down
Loading

0 comments on commit 0c851e0

Please sign in to comment.