Skip to content

Commit

Permalink
Turn on import sorting (#168)
Browse files Browse the repository at this point in the history
  • Loading branch information
norabelrose authored Apr 5, 2023
1 parent 1dc6023 commit 7f0aaf1
Show file tree
Hide file tree
Showing 36 changed files with 123 additions and 94 deletions.
2 changes: 1 addition & 1 deletion elk/__init__.py
Original file line number Diff line number Diff line change
@@ -1,3 +1,3 @@
from .extraction import extract_hiddens, Extract
from .extraction import Extract, extract_hiddens

__all__ = ["extract_hiddens", "Extract"]
5 changes: 3 additions & 2 deletions elk/calibration.py
Original file line number Diff line number Diff line change
@@ -1,8 +1,9 @@
import warnings
from dataclasses import dataclass, field
from torch import Tensor
from typing import NamedTuple

import torch
import warnings
from torch import Tensor


class CalibrationEstimate(NamedTuple):
Expand Down
3 changes: 2 additions & 1 deletion elk/eigsh.py
Original file line number Diff line number Diff line change
@@ -1,7 +1,8 @@
from torch import Tensor
from typing import Literal, Optional

import torch
import torch.nn.functional as F
from torch import Tensor


def lanczos_eigsh(
Expand Down
4 changes: 2 additions & 2 deletions elk/extraction/__init__.py
Original file line number Diff line number Diff line change
@@ -1,6 +1,6 @@
from .balanced_sampler import BalancedSampler, FewShotSampler
from .extraction import Extract, extract_hiddens, extract
from .generator import _GeneratorConfig, _GeneratorBuilder
from .extraction import Extract, extract, extract_hiddens
from .generator import _GeneratorBuilder, _GeneratorConfig
from .prompt_loading import PromptConfig, load_prompts

__all__ = [
Expand Down
12 changes: 7 additions & 5 deletions elk/extraction/balanced_sampler.py
Original file line number Diff line number Diff line change
@@ -1,12 +1,14 @@
from ..math_util import stochastic_round_constrained
from ..utils import infer_label_column
from ..utils.typing import assert_type
from collections import deque
from datasets import IterableDataset, Features
from itertools import cycle
from random import Random
from typing import Iterable, Iterator, Optional

from datasets import Features, IterableDataset
from torch.utils.data import IterableDataset as TorchIterableDataset
from typing import Iterator, Optional, Iterable

from ..math_util import stochastic_round_constrained
from ..utils import infer_label_column
from ..utils.typing import assert_type


class BalancedSampler(TorchIterableDataset):
Expand Down
6 changes: 3 additions & 3 deletions elk/extraction/extraction.py
Original file line number Diff line number Diff line change
Expand Up @@ -6,9 +6,6 @@
from typing import Iterable, Literal, Optional, Union

import torch
from simple_parsing import Serializable, field
from transformers import AutoConfig, AutoModel, AutoTokenizer, PreTrainedModel

from datasets import (
Array3D,
ClassLabel,
Expand All @@ -20,6 +17,9 @@
Value,
get_dataset_config_info,
)
from simple_parsing import Serializable, field
from transformers import AutoConfig, AutoModel, AutoTokenizer, PreTrainedModel

from elk.utils.typing import float32_to_int16

from ..utils import (
Expand Down
2 changes: 1 addition & 1 deletion elk/extraction/generator.py
Original file line number Diff line number Diff line change
@@ -1,5 +1,5 @@
from dataclasses import dataclass, field
from typing import Callable, Optional, Any, Dict
from typing import Any, Callable, Dict, Optional

import datasets
from datasets.splits import NamedSplit
Expand Down
22 changes: 12 additions & 10 deletions elk/extraction/prompt_loading.py
Original file line number Diff line number Diff line change
@@ -1,3 +1,15 @@
from dataclasses import dataclass
from random import Random
from typing import Any, Iterator, Literal, Optional

from datasets import (
Dataset,
Features,
load_dataset,
)
from datasets.distributed import split_dataset_by_node
from simple_parsing.helpers import Serializable, field

from ..promptsource import DatasetTemplates
from ..utils import (
assert_type,
Expand All @@ -7,16 +19,6 @@
select_train_val_splits,
)
from .balanced_sampler import FewShotSampler
from dataclasses import dataclass
from datasets import (
load_dataset,
Dataset,
Features,
)
from datasets.distributed import split_dataset_by_node
from random import Random
from simple_parsing.helpers import field, Serializable
from typing import Any, Iterator, Literal, Optional


@dataclass
Expand Down
4 changes: 2 additions & 2 deletions elk/files.py
Original file line number Diff line number Diff line change
@@ -1,13 +1,13 @@
"""Helper functions for dealing with files."""

from pathlib import Path
import json
import os
import random
from pathlib import Path
from typing import Optional

from simple_parsing import Serializable
import yaml
from simple_parsing import Serializable


def elk_reporter_dir() -> Path:
Expand Down
1 change: 1 addition & 0 deletions elk/logging.py
Original file line number Diff line number Diff line change
@@ -1,4 +1,5 @@
import logging

from .utils import select_train_val_splits


Expand Down
3 changes: 2 additions & 1 deletion elk/math_util.py
Original file line number Diff line number Diff line change
@@ -1,7 +1,8 @@
from torch import Tensor
import math
import random

import torch
from torch import Tensor


@torch.jit.script
Expand Down
1 change: 1 addition & 0 deletions elk/parsing.py
Original file line number Diff line number Diff line change
@@ -1,4 +1,5 @@
import re

from .training.losses import LOSSES


Expand Down
12 changes: 6 additions & 6 deletions elk/promptsource/templates.py
Original file line number Diff line number Diff line change
@@ -1,14 +1,14 @@
from collections import Counter, defaultdict
from jinja2 import BaseLoader, Environment, meta
from pathlib import Path
from shutil import rmtree
from typing import Optional
import logging
import os
import random
import uuid
import yaml
from collections import Counter, defaultdict
from pathlib import Path
from shutil import rmtree
from typing import Optional

import yaml
from jinja2 import BaseLoader, Environment, meta

# Truncation of jinja template variables
# 1710 = 300 words x 4.7 avg characters per word + 300 spaces
Expand Down
6 changes: 3 additions & 3 deletions elk/run.py
Original file line number Diff line number Diff line change
Expand Up @@ -5,10 +5,10 @@
from pathlib import Path
from typing import (
TYPE_CHECKING,
Optional,
Union,
Callable,
Iterator,
Optional,
Union,
)

import numpy as np
Expand All @@ -21,7 +21,7 @@
from elk.extraction.extraction import extract
from elk.files import create_output_directory, save_config, save_meta
from elk.training.preprocessing import normalize
from elk.utils.csv import write_iterator_to_file, Log
from elk.utils.csv import Log, write_iterator_to_file
from elk.utils.data_utils import get_layers, select_train_val_splits
from elk.utils.typing import assert_type, int16_to_float32

Expand Down
1 change: 0 additions & 1 deletion elk/training/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,7 +2,6 @@
from .eigen_reporter import EigenReporter, EigenReporterConfig
from .reporter import OptimConfig, Reporter, ReporterConfig


__all__ = [
"Reporter",
"ReporterConfig",
Expand Down
18 changes: 10 additions & 8 deletions elk/training/ccs_reporter.py
Original file line number Diff line number Diff line change
@@ -1,17 +1,19 @@
"""An ELK reporter network."""

from ..parsing import parse_loss
from ..utils.typing import assert_type
from .losses import LOSSES
from .reporter import Reporter, ReporterConfig
import math
from copy import deepcopy
from dataclasses import dataclass, field
from torch import Tensor
from torch.nn.functional import binary_cross_entropy as bce
from typing import cast, Literal, Optional
import math
from typing import Literal, Optional, cast

import torch
import torch.nn as nn
from torch import Tensor
from torch.nn.functional import binary_cross_entropy as bce

from ..parsing import parse_loss
from ..utils.typing import assert_type
from .losses import LOSSES
from .reporter import Reporter, ReporterConfig


@dataclass
Expand Down
9 changes: 6 additions & 3 deletions elk/training/classifier.py
Original file line number Diff line number Diff line change
@@ -1,11 +1,14 @@
from dataclasses import dataclass
from typing import Optional

import torch
from torch import Tensor
from torch.nn.functional import (
binary_cross_entropy_with_logits as bce_with_logits,
)
from torch.nn.functional import (
cross_entropy,
)
from torch import Tensor
from typing import Optional
import torch


@dataclass
Expand Down
10 changes: 6 additions & 4 deletions elk/training/eigen_reporter.py
Original file line number Diff line number Diff line change
@@ -1,12 +1,14 @@
"""An ELK reporter network."""

from ..math_util import cov_mean_fused
from ..eigsh import lanczos_eigsh
from .reporter import Reporter, ReporterConfig
from dataclasses import dataclass
from torch import nn, optim, Tensor
from typing import Optional

import torch
from torch import Tensor, nn, optim

from ..eigsh import lanczos_eigsh
from ..math_util import cov_mean_fused
from .reporter import Reporter, ReporterConfig


@dataclass
Expand Down
5 changes: 3 additions & 2 deletions elk/training/losses.py
Original file line number Diff line number Diff line change
@@ -1,10 +1,11 @@
"""Loss functions for training reporters."""

from torch import Tensor
import torch
import warnings
from inspect import signature

import torch
from torch import Tensor

LOSSES = dict() # Registry of loss functions


Expand Down
1 change: 1 addition & 0 deletions elk/training/preprocessing.py
Original file line number Diff line number Diff line change
@@ -1,6 +1,7 @@
"""Preprocessing functions for training."""

from typing import Literal

import torch


Expand Down
12 changes: 7 additions & 5 deletions elk/training/reporter.py
Original file line number Diff line number Diff line change
@@ -1,16 +1,18 @@
"""An ELK reporter network."""

from ..calibration import CalibrationError
from .classifier import Classifier
from abc import ABC, abstractmethod
from dataclasses import dataclass
from pathlib import Path
from simple_parsing.helpers import Serializable
from sklearn.metrics import roc_auc_score
from torch import Tensor
from typing import Literal, NamedTuple, Optional, Union

import torch
import torch.nn as nn
from simple_parsing.helpers import Serializable
from sklearn.metrics import roc_auc_score
from torch import Tensor

from ..calibration import CalibrationError
from .classifier import Classifier


class EvalResult(NamedTuple):
Expand Down
4 changes: 2 additions & 2 deletions elk/training/train.py
Original file line number Diff line number Diff line change
Expand Up @@ -5,7 +5,7 @@
from dataclasses import dataclass
from functools import partial
from pathlib import Path
from typing import Literal, Optional, Callable
from typing import Callable, Literal, Optional

import torch
from simple_parsing import Serializable, field, subgroups
Expand All @@ -16,12 +16,12 @@
from elk.run import Run
from elk.utils.typing import assert_type

from ..utils import select_usable_devices
from .ccs_reporter import CcsReporter, CcsReporterConfig
from .classifier import Classifier
from .eigen_reporter import EigenReporter, EigenReporterConfig
from .reporter import OptimConfig, Reporter, ReporterConfig
from .train_log import ElicitLog
from ..utils import select_usable_devices


@dataclass
Expand Down
1 change: 0 additions & 1 deletion elk/utils/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -5,7 +5,6 @@
infer_num_classes,
select_train_val_splits,
)

from .gpu_utils import select_usable_devices
from .tree_utils import pytree_map
from .typing import assert_type, float32_to_int16, int16_to_float32
Expand Down
2 changes: 1 addition & 1 deletion elk/utils/csv.py
Original file line number Diff line number Diff line change
@@ -1,6 +1,6 @@
import csv
from pathlib import Path
from typing import Iterator, Callable, TextIO, TypeVar
from typing import Callable, Iterator, TextIO, TypeVar

from datasets import DatasetDict

Expand Down
12 changes: 7 additions & 5 deletions elk/utils/data_utils.py
Original file line number Diff line number Diff line change
@@ -1,15 +1,17 @@
from .typing import assert_type
from ..promptsource.templates import Template
import copy
from random import Random
from typing import Any, Iterable, List

from datasets import (
ClassLabel,
DatasetDict,
Features,
Split,
Value,
)
from random import Random
from typing import Iterable, List, Any
import copy

from ..promptsource.templates import Template
from .typing import assert_type


def get_columns_all_equal(dataset: DatasetDict) -> list[str]:
Expand Down
10 changes: 6 additions & 4 deletions elk/utils/gpu_utils.py
Original file line number Diff line number Diff line change
@@ -1,12 +1,14 @@
"""Utilities that use PyNVML to get GPU usage info, and select GPUs accordingly."""

from .typing import assert_type
from typing import Optional
import os
import time
import warnings
from typing import Optional

import pynvml
import torch
import warnings
import time

from .typing import assert_type


def select_usable_devices(
Expand Down
Loading

0 comments on commit 7f0aaf1

Please sign in to comment.