diff --git a/elk/__init__.py b/elk/__init__.py index bb9f9b17..9d95a485 100644 --- a/elk/__init__.py +++ b/elk/__init__.py @@ -1,3 +1,3 @@ -from .extraction import extract_hiddens, Extract +from .extraction import Extract, extract_hiddens __all__ = ["extract_hiddens", "Extract"] diff --git a/elk/calibration.py b/elk/calibration.py index 3d494872..db56fa02 100644 --- a/elk/calibration.py +++ b/elk/calibration.py @@ -1,8 +1,9 @@ +import warnings from dataclasses import dataclass, field -from torch import Tensor from typing import NamedTuple + import torch -import warnings +from torch import Tensor class CalibrationEstimate(NamedTuple): diff --git a/elk/eigsh.py b/elk/eigsh.py index 10c1de60..c7277e32 100644 --- a/elk/eigsh.py +++ b/elk/eigsh.py @@ -1,7 +1,8 @@ -from torch import Tensor from typing import Literal, Optional + import torch import torch.nn.functional as F +from torch import Tensor def lanczos_eigsh( diff --git a/elk/extraction/__init__.py b/elk/extraction/__init__.py index 8b00ea3d..fac876fa 100644 --- a/elk/extraction/__init__.py +++ b/elk/extraction/__init__.py @@ -1,6 +1,6 @@ from .balanced_sampler import BalancedSampler, FewShotSampler -from .extraction import Extract, extract_hiddens, extract -from .generator import _GeneratorConfig, _GeneratorBuilder +from .extraction import Extract, extract, extract_hiddens +from .generator import _GeneratorBuilder, _GeneratorConfig from .prompt_loading import PromptConfig, load_prompts __all__ = [ diff --git a/elk/extraction/balanced_sampler.py b/elk/extraction/balanced_sampler.py index 878d2935..e2e7f1f7 100644 --- a/elk/extraction/balanced_sampler.py +++ b/elk/extraction/balanced_sampler.py @@ -1,12 +1,14 @@ -from ..math_util import stochastic_round_constrained -from ..utils import infer_label_column -from ..utils.typing import assert_type from collections import deque -from datasets import IterableDataset, Features from itertools import cycle from random import Random +from typing import Iterable, Iterator, Optional + +from datasets import Features, IterableDataset from torch.utils.data import IterableDataset as TorchIterableDataset -from typing import Iterator, Optional, Iterable + +from ..math_util import stochastic_round_constrained +from ..utils import infer_label_column +from ..utils.typing import assert_type class BalancedSampler(TorchIterableDataset): diff --git a/elk/extraction/extraction.py b/elk/extraction/extraction.py index b0b55d0f..187428fc 100644 --- a/elk/extraction/extraction.py +++ b/elk/extraction/extraction.py @@ -6,9 +6,6 @@ from typing import Iterable, Literal, Optional, Union import torch -from simple_parsing import Serializable, field -from transformers import AutoConfig, AutoModel, AutoTokenizer, PreTrainedModel - from datasets import ( Array3D, ClassLabel, @@ -20,6 +17,9 @@ Value, get_dataset_config_info, ) +from simple_parsing import Serializable, field +from transformers import AutoConfig, AutoModel, AutoTokenizer, PreTrainedModel + from elk.utils.typing import float32_to_int16 from ..utils import ( diff --git a/elk/extraction/generator.py b/elk/extraction/generator.py index fbf10848..fb4d03bc 100644 --- a/elk/extraction/generator.py +++ b/elk/extraction/generator.py @@ -1,5 +1,5 @@ from dataclasses import dataclass, field -from typing import Callable, Optional, Any, Dict +from typing import Any, Callable, Dict, Optional import datasets from datasets.splits import NamedSplit diff --git a/elk/extraction/prompt_loading.py b/elk/extraction/prompt_loading.py index 02471a9d..a494be5a 100644 --- a/elk/extraction/prompt_loading.py +++ b/elk/extraction/prompt_loading.py @@ -1,3 +1,15 @@ +from dataclasses import dataclass +from random import Random +from typing import Any, Iterator, Literal, Optional + +from datasets import ( + Dataset, + Features, + load_dataset, +) +from datasets.distributed import split_dataset_by_node +from simple_parsing.helpers import Serializable, field + from ..promptsource import DatasetTemplates from ..utils import ( assert_type, @@ -7,16 +19,6 @@ select_train_val_splits, ) from .balanced_sampler import FewShotSampler -from dataclasses import dataclass -from datasets import ( - load_dataset, - Dataset, - Features, -) -from datasets.distributed import split_dataset_by_node -from random import Random -from simple_parsing.helpers import field, Serializable -from typing import Any, Iterator, Literal, Optional @dataclass diff --git a/elk/files.py b/elk/files.py index 47225876..84226706 100644 --- a/elk/files.py +++ b/elk/files.py @@ -1,13 +1,13 @@ """Helper functions for dealing with files.""" -from pathlib import Path import json import os import random +from pathlib import Path from typing import Optional -from simple_parsing import Serializable import yaml +from simple_parsing import Serializable def elk_reporter_dir() -> Path: diff --git a/elk/logging.py b/elk/logging.py index 19bb12c3..706055bd 100644 --- a/elk/logging.py +++ b/elk/logging.py @@ -1,4 +1,5 @@ import logging + from .utils import select_train_val_splits diff --git a/elk/math_util.py b/elk/math_util.py index 7b5cd38c..4ae9daee 100644 --- a/elk/math_util.py +++ b/elk/math_util.py @@ -1,7 +1,8 @@ -from torch import Tensor import math import random + import torch +from torch import Tensor @torch.jit.script diff --git a/elk/parsing.py b/elk/parsing.py index c40a8473..1daded78 100644 --- a/elk/parsing.py +++ b/elk/parsing.py @@ -1,4 +1,5 @@ import re + from .training.losses import LOSSES diff --git a/elk/promptsource/templates.py b/elk/promptsource/templates.py index a19526e1..8f1828f8 100644 --- a/elk/promptsource/templates.py +++ b/elk/promptsource/templates.py @@ -1,14 +1,14 @@ -from collections import Counter, defaultdict -from jinja2 import BaseLoader, Environment, meta -from pathlib import Path -from shutil import rmtree -from typing import Optional import logging import os import random import uuid -import yaml +from collections import Counter, defaultdict +from pathlib import Path +from shutil import rmtree +from typing import Optional +import yaml +from jinja2 import BaseLoader, Environment, meta # Truncation of jinja template variables # 1710 = 300 words x 4.7 avg characters per word + 300 spaces diff --git a/elk/run.py b/elk/run.py index 7b25ccd7..723338e1 100644 --- a/elk/run.py +++ b/elk/run.py @@ -5,10 +5,10 @@ from pathlib import Path from typing import ( TYPE_CHECKING, - Optional, - Union, Callable, Iterator, + Optional, + Union, ) import numpy as np @@ -21,7 +21,7 @@ from elk.extraction.extraction import extract from elk.files import create_output_directory, save_config, save_meta from elk.training.preprocessing import normalize -from elk.utils.csv import write_iterator_to_file, Log +from elk.utils.csv import Log, write_iterator_to_file from elk.utils.data_utils import get_layers, select_train_val_splits from elk.utils.typing import assert_type, int16_to_float32 diff --git a/elk/training/__init__.py b/elk/training/__init__.py index a9d76f05..41264179 100644 --- a/elk/training/__init__.py +++ b/elk/training/__init__.py @@ -2,7 +2,6 @@ from .eigen_reporter import EigenReporter, EigenReporterConfig from .reporter import OptimConfig, Reporter, ReporterConfig - __all__ = [ "Reporter", "ReporterConfig", diff --git a/elk/training/ccs_reporter.py b/elk/training/ccs_reporter.py index 35ee67ec..258f2bc4 100644 --- a/elk/training/ccs_reporter.py +++ b/elk/training/ccs_reporter.py @@ -1,17 +1,19 @@ """An ELK reporter network.""" -from ..parsing import parse_loss -from ..utils.typing import assert_type -from .losses import LOSSES -from .reporter import Reporter, ReporterConfig +import math from copy import deepcopy from dataclasses import dataclass, field -from torch import Tensor -from torch.nn.functional import binary_cross_entropy as bce -from typing import cast, Literal, Optional -import math +from typing import Literal, Optional, cast + import torch import torch.nn as nn +from torch import Tensor +from torch.nn.functional import binary_cross_entropy as bce + +from ..parsing import parse_loss +from ..utils.typing import assert_type +from .losses import LOSSES +from .reporter import Reporter, ReporterConfig @dataclass diff --git a/elk/training/classifier.py b/elk/training/classifier.py index cab88dd9..726cae7a 100644 --- a/elk/training/classifier.py +++ b/elk/training/classifier.py @@ -1,11 +1,14 @@ from dataclasses import dataclass +from typing import Optional + +import torch +from torch import Tensor from torch.nn.functional import ( binary_cross_entropy_with_logits as bce_with_logits, +) +from torch.nn.functional import ( cross_entropy, ) -from torch import Tensor -from typing import Optional -import torch @dataclass diff --git a/elk/training/eigen_reporter.py b/elk/training/eigen_reporter.py index 77582f9f..62e0a543 100644 --- a/elk/training/eigen_reporter.py +++ b/elk/training/eigen_reporter.py @@ -1,12 +1,14 @@ """An ELK reporter network.""" -from ..math_util import cov_mean_fused -from ..eigsh import lanczos_eigsh -from .reporter import Reporter, ReporterConfig from dataclasses import dataclass -from torch import nn, optim, Tensor from typing import Optional + import torch +from torch import Tensor, nn, optim + +from ..eigsh import lanczos_eigsh +from ..math_util import cov_mean_fused +from .reporter import Reporter, ReporterConfig @dataclass diff --git a/elk/training/losses.py b/elk/training/losses.py index d91c1e79..8d7e287b 100644 --- a/elk/training/losses.py +++ b/elk/training/losses.py @@ -1,10 +1,11 @@ """Loss functions for training reporters.""" -from torch import Tensor -import torch import warnings from inspect import signature +import torch +from torch import Tensor + LOSSES = dict() # Registry of loss functions diff --git a/elk/training/preprocessing.py b/elk/training/preprocessing.py index 2802bb9e..6081dcbb 100644 --- a/elk/training/preprocessing.py +++ b/elk/training/preprocessing.py @@ -1,6 +1,7 @@ """Preprocessing functions for training.""" from typing import Literal + import torch diff --git a/elk/training/reporter.py b/elk/training/reporter.py index ea8a4406..cf3e21e2 100644 --- a/elk/training/reporter.py +++ b/elk/training/reporter.py @@ -1,16 +1,18 @@ """An ELK reporter network.""" -from ..calibration import CalibrationError -from .classifier import Classifier from abc import ABC, abstractmethod from dataclasses import dataclass from pathlib import Path -from simple_parsing.helpers import Serializable -from sklearn.metrics import roc_auc_score -from torch import Tensor from typing import Literal, NamedTuple, Optional, Union + import torch import torch.nn as nn +from simple_parsing.helpers import Serializable +from sklearn.metrics import roc_auc_score +from torch import Tensor + +from ..calibration import CalibrationError +from .classifier import Classifier class EvalResult(NamedTuple): diff --git a/elk/training/train.py b/elk/training/train.py index 81a9bde6..caa653ad 100644 --- a/elk/training/train.py +++ b/elk/training/train.py @@ -5,7 +5,7 @@ from dataclasses import dataclass from functools import partial from pathlib import Path -from typing import Literal, Optional, Callable +from typing import Callable, Literal, Optional import torch from simple_parsing import Serializable, field, subgroups @@ -16,12 +16,12 @@ from elk.run import Run from elk.utils.typing import assert_type +from ..utils import select_usable_devices from .ccs_reporter import CcsReporter, CcsReporterConfig from .classifier import Classifier from .eigen_reporter import EigenReporter, EigenReporterConfig from .reporter import OptimConfig, Reporter, ReporterConfig from .train_log import ElicitLog -from ..utils import select_usable_devices @dataclass diff --git a/elk/utils/__init__.py b/elk/utils/__init__.py index db715d70..186b544a 100644 --- a/elk/utils/__init__.py +++ b/elk/utils/__init__.py @@ -5,7 +5,6 @@ infer_num_classes, select_train_val_splits, ) - from .gpu_utils import select_usable_devices from .tree_utils import pytree_map from .typing import assert_type, float32_to_int16, int16_to_float32 diff --git a/elk/utils/csv.py b/elk/utils/csv.py index c13cf1d6..f4f95c47 100644 --- a/elk/utils/csv.py +++ b/elk/utils/csv.py @@ -1,6 +1,6 @@ import csv from pathlib import Path -from typing import Iterator, Callable, TextIO, TypeVar +from typing import Callable, Iterator, TextIO, TypeVar from datasets import DatasetDict diff --git a/elk/utils/data_utils.py b/elk/utils/data_utils.py index e0055484..9636b546 100644 --- a/elk/utils/data_utils.py +++ b/elk/utils/data_utils.py @@ -1,5 +1,7 @@ -from .typing import assert_type -from ..promptsource.templates import Template +import copy +from random import Random +from typing import Any, Iterable, List + from datasets import ( ClassLabel, DatasetDict, @@ -7,9 +9,9 @@ Split, Value, ) -from random import Random -from typing import Iterable, List, Any -import copy + +from ..promptsource.templates import Template +from .typing import assert_type def get_columns_all_equal(dataset: DatasetDict) -> list[str]: diff --git a/elk/utils/gpu_utils.py b/elk/utils/gpu_utils.py index 9074bb8a..f305c6c1 100644 --- a/elk/utils/gpu_utils.py +++ b/elk/utils/gpu_utils.py @@ -1,12 +1,14 @@ """Utilities that use PyNVML to get GPU usage info, and select GPUs accordingly.""" -from .typing import assert_type -from typing import Optional import os +import time +import warnings +from typing import Optional + import pynvml import torch -import warnings -import time + +from .typing import assert_type def select_usable_devices( diff --git a/elk/utils/tree_utils.py b/elk/utils/tree_utils.py index 9a5b0fb1..f084874f 100644 --- a/elk/utils/tree_utils.py +++ b/elk/utils/tree_utils.py @@ -6,7 +6,6 @@ from typing import Callable, Mapping, TypeVar - TreeType = TypeVar("TreeType") diff --git a/elk/utils/typing.py b/elk/utils/typing.py index ea552c83..1d38040e 100644 --- a/elk/utils/typing.py +++ b/elk/utils/typing.py @@ -1,8 +1,7 @@ -from typing import cast, Any, Type, TypeVar +from typing import Any, Type, TypeVar, cast import torch - T = TypeVar("T") diff --git a/pyproject.toml b/pyproject.toml index e22a3513..49223931 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -57,9 +57,9 @@ testpaths = ["tests"] include = ["elk*"] [tool.ruff] -# Enable pycodestyle (`E`) and Pyflakes (`F`) codes +# Enable pycodestyle (`E`), Pyflakes (`F`), and isort (`I`) codes # See https://beta.ruff.rs/docs/rules/ for more possible rules -select = ["E", "F"] +select = ["E", "F", "I"] # Same as Black. line-length = 88 # Avoid automatically removing unused imports in __init__.py files. diff --git a/tests/test_classifier.py b/tests/test_classifier.py index 85c3c57d..bdc9023d 100644 --- a/tests/test_classifier.py +++ b/tests/test_classifier.py @@ -1,7 +1,8 @@ -from elk.training.classifier import Classifier +import torch from sklearn.datasets import make_classification from sklearn.linear_model import LogisticRegression -import torch + +from elk.training.classifier import Classifier @torch.no_grad() diff --git a/tests/test_eigen_reporter.py b/tests/test_eigen_reporter.py index 9d548f67..58dd6c13 100644 --- a/tests/test_eigen_reporter.py +++ b/tests/test_eigen_reporter.py @@ -1,6 +1,7 @@ +import torch + from elk.math_util import batch_cov, cov_mean_fused from elk.training import EigenReporter, EigenReporterConfig -import torch def test_eigen_reporter(): diff --git a/tests/test_eigsh.py b/tests/test_eigsh.py index b208dc23..c34ce2b3 100644 --- a/tests/test_eigsh.py +++ b/tests/test_eigsh.py @@ -1,8 +1,9 @@ -from elk.eigsh import lanczos_eigsh -from scipy.sparse.linalg import eigsh import numpy as np import pytest import torch +from scipy.sparse.linalg import eigsh + +from elk.eigsh import lanczos_eigsh @pytest.mark.parametrize("n", [20, 40]) diff --git a/tests/test_load_prompts.py b/tests/test_load_prompts.py index c9a45f03..a5d238fd 100644 --- a/tests/test_load_prompts.py +++ b/tests/test_load_prompts.py @@ -1,9 +1,11 @@ -from elk.extraction import load_prompts, PromptConfig -from elk.promptsource.templates import DatasetTemplates from itertools import cycle, islice from typing import Literal + import pytest +from elk.extraction import PromptConfig, load_prompts +from elk.promptsource.templates import DatasetTemplates + @pytest.mark.filterwarnings("ignore:Unable to find a decoding function") def test_load_prompts(): diff --git a/tests/test_math.py b/tests/test_math.py index c6db4e93..ee81914e 100644 --- a/tests/test_math.py +++ b/tests/test_math.py @@ -1,9 +1,12 @@ -from elk.math_util import batch_cov, cov_mean_fused, stochastic_round_constrained -from hypothesis import given, strategies as st -from random import Random import math +from random import Random + import numpy as np import torch +from hypothesis import given +from hypothesis import strategies as st + +from elk.math_util import batch_cov, cov_mean_fused, stochastic_round_constrained def test_cov_mean_fused(): diff --git a/tests/test_samplers.py b/tests/test_samplers.py index cb5e1225..87c1ac0c 100644 --- a/tests/test_samplers.py +++ b/tests/test_samplers.py @@ -1,10 +1,12 @@ from collections import Counter -from datasets import load_dataset, IterableDataset -from elk.extraction import FewShotSampler, BalancedSampler -from elk.utils import assert_type, infer_label_column from itertools import islice from random import Random +from datasets import IterableDataset, load_dataset + +from elk.extraction import BalancedSampler, FewShotSampler +from elk.utils import assert_type, infer_label_column + def test_output_batches_are_balanced(): # Load an example dataset for testing diff --git a/tests/test_write_iterator_to_file.py b/tests/test_write_iterator_to_file.py index a40d8e2f..25fe7567 100644 --- a/tests/test_write_iterator_to_file.py +++ b/tests/test_write_iterator_to_file.py @@ -1,15 +1,14 @@ import csv +import multiprocessing as mp import time from pathlib import Path from typing import Iterator -import multiprocessing as mp - from datasets import DatasetDict -from elk.utils.csv import write_iterator_to_file from elk.training.reporter import EvalResult from elk.training.train_log import ElicitLog +from elk.utils.csv import write_iterator_to_file def test_write_iterator_to_file(tmp_path: Path):