Skip to content

Commit

Permalink
Fixes 256 - ValueError mutable default (#263)
Browse files Browse the repository at this point in the history
* Fixes 256 - ValueError mutable default

Fixes issue #256. Error message:

> ValueError: mutable default <class 'elk.training.train.Elicit'> for field run_template is not allowed: use
default_factory

* [pre-commit.ci] auto fixes from pre-commit.com hooks

for more information, see https://pre-commit.ci

* fix

* [pre-commit.ci] auto fixes from pre-commit.com hooks

for more information, see https://pre-commit.ci

* To lowercase

---------

Co-authored-by: pre-commit-ci[bot] <66853113+pre-commit-ci[bot]@users.noreply.github.com>
  • Loading branch information
artkpv and pre-commit-ci[bot] authored Sep 25, 2023
1 parent 14669b1 commit 670eaec
Show file tree
Hide file tree
Showing 2 changed files with 12 additions and 8 deletions.
10 changes: 2 additions & 8 deletions elk/training/sweep.py
Original file line number Diff line number Diff line change
@@ -1,12 +1,11 @@
from dataclasses import InitVar, dataclass, replace
from dataclasses import InitVar, dataclass, field, replace

import numpy as np
import torch
from datasets import get_dataset_config_info
from transformers import AutoConfig

from ..evaluation import Eval
from ..extraction import Extract
from ..files import memorably_named_dir, sweeps_dir
from ..plotting.visualize import visualize_sweep
from ..training.eigen_reporter import EigenFitterConfig
Expand Down Expand Up @@ -53,12 +52,7 @@ class Sweep:
name: str | None = None

# A bit of a hack to add all the command line arguments from Elicit
run_template: Elicit = Elicit(
data=Extract(
model="<placeholder>",
datasets=("<placeholder>",),
)
)
run_template: Elicit = field(default_factory=Elicit.default)

def __post_init__(self, add_pooled: bool):
if not self.datasets:
Expand Down
10 changes: 10 additions & 0 deletions elk/training/train.py
Original file line number Diff line number Diff line change
Expand Up @@ -11,6 +11,7 @@
from simple_parsing import subgroups
from simple_parsing.helpers.serialization import save

from ..extraction import Extract
from ..metrics import evaluate_preds, to_one_hot
from ..run import Run
from ..training.supervised import train_supervised
Expand All @@ -34,6 +35,15 @@ class Elicit(Run):
cross-validation. Defaults to "single", which means to train a single classifier
on the training data. "cv" means to use cross-validation."""

@staticmethod
def default():
return Elicit(
data=Extract(
model="<placeholder>",
datasets=("<placeholder>",),
)
)

def create_models_dir(self, out_dir: Path):
lr_dir = None
lr_dir = out_dir / "lr_models"
Expand Down

0 comments on commit 670eaec

Please sign in to comment.