Skip to content

Commit

Permalink
raise if dataset_type is not set and dataset_config is called
Browse files Browse the repository at this point in the history
  • Loading branch information
Alessandro Sordoni committed Aug 12, 2024
1 parent 81c08e3 commit a4073b7
Showing 1 changed file with 37 additions and 8 deletions.
45 changes: 37 additions & 8 deletions mttl/config.py
Original file line number Diff line number Diff line change
Expand Up @@ -317,6 +317,16 @@ class TransformArgs(

@dataclass
class TrainingArgs(DataArgs):
# model arguments
model: str = None
model_family: str = None
attn_implementation: str = None
device_map: str = "cpu"
load_in_4bit: bool = False
load_in_8bit: bool = False
do_train: bool = True

# output directories
cache_dir: str = os.getenv("CACHE_DIR", "./cache")
data_dir: str = os.getenv("TRAIN_DIR", "/tmp/")
output_dir: str = os.getenv("OUTPUT_DIR", "./output")
Expand Down Expand Up @@ -359,14 +369,6 @@ class TrainingArgs(DataArgs):
precision: str = "32"
monitor_grad_alignment_on: str = None

model: str = None
model_family: str = None
attn_implementation: str = None
device_map: str = "cpu"
load_in_4bit: bool = False
load_in_8bit: bool = False
do_train: bool = True

# logging
wandb_project: str = None
tensorboard: bool = False
Expand Down Expand Up @@ -402,6 +404,10 @@ def dataset_config(self):
DataArgs.registrable_class.get_config_class_by_name(self.dataset_type),
self,
)
else:
raise ValueError(
"Trying to access dataset config without specifying `dataset_type`!"
)

def __post_init__(self):
if self.attn_implementation == "eager" and self.pack_sequences:
Expand Down Expand Up @@ -430,6 +436,18 @@ def __post_init__(self):
+ "into account when computing `gradient_accumulation_steps`."
)

if self.model_family is None:
# infer model family automatically
if "t5" in self.model or "T0" in self.model:
self.model_family = "seq2seq"
else:
self.model_family = "gpt"

logger.warn(
"Model family was not specified, inferring from model name:",
self.model_family,
)


@dataclass
class ExpertConfig(TrainingArgs, ModifierArgs):
Expand Down Expand Up @@ -515,3 +533,14 @@ class MoEExpertConfig(MultiExpertConfig):
moe_ent_reg: float = 0.0
moe_ent_free_bits: float = 0.0
moe_num_experts: int = 8


@dataclass
class RankerConfig(TrainingArgs, SelectorArgs):
    """Training configuration for an expert ranker.

    Pairs a text encoder with learned expert embeddings projected into a
    shared space (dimensions below). Inherits all generic training knobs
    from ``TrainingArgs`` and selector options from ``SelectorArgs``.
    """

    # Name of the text encoder; presumably a sentence-transformers model,
    # and 384 matches all-MiniLM-L6-v2's output width — TODO confirm.
    encoder_model_name: str = "all-MiniLM-L6-v2"
    text_embedding_dim: int = 384
    expert_embedding_dim: int = 512
    # Common dimension both embeddings are projected into.
    projection_dim: int = 512
    # BUGFIX: this was previously unannotated (`val_check_interval = 1.0`),
    # so @dataclass treated it as a plain class attribute instead of a field
    # (excluded from __init__/asdict/fields, unlike its siblings below).
    # Annotating it makes it a proper field with the same default.
    val_check_interval: float = 1.0
    limit_val_batches: float = 1.0
    limit_train_batches: float = 1.0

0 comments on commit a4073b7

Please sign in to comment.