diff --git a/mttl/config.py b/mttl/config.py
index e791eb3be..b939f188d 100644
--- a/mttl/config.py
+++ b/mttl/config.py
@@ -317,6 +317,16 @@ class TransformArgs(
 
 @dataclass
 class TrainingArgs(DataArgs):
+    # model arguments
+    model: str = None
+    model_family: str = None
+    attn_implementation: str = None
+    device_map: str = "cpu"
+    load_in_4bit: bool = False
+    load_in_8bit: bool = False
+    do_train: bool = True
+
+    # output directories
     cache_dir: str = os.getenv("CACHE_DIR", "./cache")
     data_dir: str = os.getenv("TRAIN_DIR", "/tmp/")
     output_dir: str = os.getenv("OUTPUT_DIR", "./output")
@@ -359,14 +369,6 @@ class TrainingArgs(DataArgs):
     precision: str = "32"
     monitor_grad_alignment_on: str = None
 
-    model: str = None
-    model_family: str = None
-    attn_implementation: str = None
-    device_map: str = "cpu"
-    load_in_4bit: bool = False
-    load_in_8bit: bool = False
-    do_train: bool = True
-
     # logging
     wandb_project: str = None
     tensorboard: bool = False
@@ -402,6 +404,10 @@ def dataset_config(self):
                 DataArgs.registrable_class.get_config_class_by_name(self.dataset_type),
                 self,
             )
+        else:
+            raise ValueError(
+                "Trying to access dataset config without specifying `dataset_type`!"
+            )
 
     def __post_init__(self):
         if self.attn_implementation == "eager" and self.pack_sequences:
@@ -430,6 +436,18 @@ def __post_init__(self):
                 + "into account when computing `gradient_accumulation_steps`."
             )
 
+        if self.model_family is None and self.model is not None:
+            # infer model family automatically
+            if "t5" in self.model or "T0" in self.model:
+                self.model_family = "seq2seq"
+            else:
+                self.model_family = "gpt"
+
+            logger.warning(
+                "Model family was not specified, inferring from model name: %s",
+                self.model_family,
+            )
+
 
 @dataclass
 class ExpertConfig(TrainingArgs, ModifierArgs):
@@ -515,3 +533,14 @@ class MoEExpertConfig(MultiExpertConfig):
     moe_ent_reg: float = 0.0
     moe_ent_free_bits: float = 0.0
     moe_num_experts: int = 8
+
+
+@dataclass
+class RankerConfig(TrainingArgs, SelectorArgs):
+    encoder_model_name: str = "all-MiniLM-L6-v2"
+    text_embedding_dim: int = 384
+    expert_embedding_dim: int = 512
+    projection_dim: int = 512
+    val_check_interval: float = 1.0
+    limit_val_batches: float = 1.0
+    limit_train_batches: float = 1.0
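
The new `else` branch in `dataset_config` turns a silent `None` return into an explicit failure, so a missing `dataset_type` surfaces at the point of access rather than as a confusing `AttributeError` downstream. A minimal self-contained sketch of the pattern (the registry dict and `ToyArgs` below are illustrative stand-ins, not mttl classes):

```python
from dataclasses import dataclass

# Hypothetical stand-in for the DataArgs.registrable_class config registry.
_DATASET_CONFIGS = {"flan": dict}


@dataclass
class ToyArgs:
    dataset_type: str = None

    @property
    def dataset_config(self):
        if self.dataset_type is not None:
            return _DATASET_CONFIGS[self.dataset_type]
        # The guard added in the diff: fail loudly instead of falling through to None.
        raise ValueError(
            "Trying to access dataset config without specifying `dataset_type`!"
        )


try:
    _ = ToyArgs().dataset_config
except ValueError as err:
    print(err)  # Trying to access dataset config without specifying `dataset_type`!
```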
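
The model-family fallback added to `__post_init__` is a case-sensitive substring heuristic: model names containing "t5" or "T0" are treated as encoder-decoder ("seq2seq"), everything else as decoder-only ("gpt"), which is why a warning is logged whenever the inference kicks in. A standalone sketch of the same logic (`infer_model_family` is a hypothetical helper for illustration; the diff inlines this in `__post_init__`):

```python
import logging

logger = logging.getLogger(__name__)


def infer_model_family(model: str) -> str:
    # Mirrors the heuristic in __post_init__: T5/T0 checkpoints -> seq2seq,
    # anything else -> gpt (decoder-only).
    family = "seq2seq" if ("t5" in model or "T0" in model) else "gpt"
    logger.warning(
        "Model family was not specified, inferring from model name: %s", family
    )
    return family


assert infer_model_family("google/flan-t5-xl") == "seq2seq"
assert infer_model_family("bigscience/T0_3B") == "seq2seq"
assert infer_model_family("EleutherAI/gpt-neo-125m") == "gpt"
assert infer_model_family("mistralai/Mistral-7B-v0.1") == "gpt"
```

Note that the heuristic would misclassify a seq2seq checkpoint whose name contains neither "t5" nor "T0" as "gpt", so passing `model_family` explicitly remains the safe option.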