From 670eaec0477685ca9542ef28a320e116e31d0182 Mon Sep 17 00:00:00 2001 From: Artyom K Date: Mon, 25 Sep 2023 22:45:43 +0300 Subject: [PATCH] Fixes 256 - ValueError mutable default (#263) * Fixes 256 - ValueError mutable default Fixes issue #256. Error message: > ValueError: mutable default for field run_template is not allowed: use default_factory * [pre-commit.ci] auto fixes from pre-commit.com hooks for more information, see https://pre-commit.ci * fix * [pre-commit.ci] auto fixes from pre-commit.com hooks for more information, see https://pre-commit.ci * To lowercase --------- Co-authored-by: pre-commit-ci[bot] <66853113+pre-commit-ci[bot]@users.noreply.github.com> --- elk/training/sweep.py | 10 ++-------- elk/training/train.py | 10 ++++++++++ 2 files changed, 12 insertions(+), 8 deletions(-) diff --git a/elk/training/sweep.py b/elk/training/sweep.py index a4e5c97ac..c9ca7fc9f 100755 --- a/elk/training/sweep.py +++ b/elk/training/sweep.py @@ -1,4 +1,4 @@ -from dataclasses import InitVar, dataclass, replace +from dataclasses import InitVar, dataclass, field, replace import numpy as np import torch @@ -6,7 +6,6 @@ from transformers import AutoConfig from ..evaluation import Eval -from ..extraction import Extract from ..files import memorably_named_dir, sweeps_dir from ..plotting.visualize import visualize_sweep from ..training.eigen_reporter import EigenFitterConfig @@ -53,12 +52,7 @@ class Sweep: name: str | None = None # A bit of a hack to add all the command line arguments from Elicit - run_template: Elicit = Elicit( - data=Extract( - model="", - datasets=("",), - ) - ) + run_template: Elicit = field(default_factory=Elicit.default) def __post_init__(self, add_pooled: bool): if not self.datasets: diff --git a/elk/training/train.py b/elk/training/train.py index fb8822405..3d00b54c1 100644 --- a/elk/training/train.py +++ b/elk/training/train.py @@ -11,6 +11,7 @@ from simple_parsing import subgroups from simple_parsing.helpers.serialization import save +from ..extraction import Extract from ..metrics import evaluate_preds, to_one_hot from ..run import Run from ..training.supervised import train_supervised @@ -34,6 +35,15 @@ class Elicit(Run): cross-validation. Defaults to "single", which means to train a single classifier on the training data. "cv" means to use cross-validation.""" + @staticmethod + def default(): + return Elicit( + data=Extract( + model="", + datasets=("",), + ) + ) + def create_models_dir(self, out_dir: Path): lr_dir = None lr_dir = out_dir / "lr_models"