make pytorch an optional dependency #2004

Open · wants to merge 2 commits into base: main
1 change: 1 addition & 0 deletions lm_eval/api/registry.py
@@ -37,6 +37,7 @@ def get_model(model_name):
except KeyError:
raise ValueError(
f"Attempted to load model '{model_name}', but no model for this name found! Supported model names: {', '.join(MODEL_REGISTRY.keys())}"
"\nIf you are trying to load a model from Hugging Face, please use the `hf` model type and please ensure torch is installed."
)


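
With torch absent, the `hf` backend never registers, so this friendlier message is what users now hit. A minimal sketch of the failure mode, assuming only the registry API shown in this hunk:

# Minimal sketch: in a torch-free install, requesting an unregistered model
# name such as "hf" raises the ValueError above, listing the backends that
# did register plus the new installation hint.
from lm_eval.api.registry import get_model

try:
    get_model("hf")
except ValueError as err:
    print(err)
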
19 changes: 16 additions & 3 deletions lm_eval/evaluator.py
@@ -7,7 +7,6 @@
from typing import TYPE_CHECKING, List, Optional, Union

import numpy as np
import torch

import lm_eval.api.metrics
import lm_eval.api.registry
@@ -38,6 +37,14 @@
from lm_eval.tasks import Task


try:
import torch

HAS_TORCH = True
except ImportError:
HAS_TORCH = False


@positional_deprecated
def simple_evaluate(
model,
@@ -145,8 +152,9 @@ def simple_evaluate(
np.random.seed(numpy_random_seed)

if torch_random_seed is not None:
seed_message.append(f"Setting torch manual seed to {torch_random_seed}")
torch.manual_seed(torch_random_seed)
if HAS_TORCH:
seed_message.append(f"Setting torch manual seed to {torch_random_seed}")
torch.manual_seed(torch_random_seed)

if seed_message:
eval_logger.info(" | ".join(seed_message))
@@ -410,6 +418,8 @@ def evaluate(
requests[reqtype].append(instance)

if lm.world_size > 1:
if not HAS_TORCH:
raise ImportError("torch is required for distributed evaluation")
instances_rnk = torch.tensor(len(task._instances), device=lm.device)
gathered_item = (
lm.accelerator.gather(instances_rnk).cpu().detach().numpy().tolist()
@@ -504,6 +514,9 @@ def evaluate(
task_output.sample_metrics[(metric, filter_key)].append(value)

if WORLD_SIZE > 1:
if not HAS_TORCH:
raise ImportError("torch is required for distributed evaluation")

# if multigpu, then gather data across all ranks to rank 0
# first gather logged samples across all ranks
for task_output in eval_tasks:
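
Taken together, the evaluator changes follow the standard optional-import guard. A condensed sketch of the pattern (the helper names here are illustrative, not the module's own):

try:
    import torch

    HAS_TORCH = True
except ImportError:
    HAS_TORCH = False


def seed_torch_if_available(torch_random_seed):
    # Seeding silently becomes a no-op when torch is not installed.
    if torch_random_seed is not None and HAS_TORCH:
        torch.manual_seed(torch_random_seed)


def require_torch_for_distributed(world_size: int) -> None:
    # Distributed evaluation still hard-requires torch, so fail fast and loudly.
    if world_size > 1 and not HAS_TORCH:
        raise ImportError("torch is required for distributed evaluation")
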
3 changes: 2 additions & 1 deletion lm_eval/loggers/utils.py
@@ -6,7 +6,6 @@
from typing import Any, Dict, Optional, Tuple, Union

import numpy as np
from torch.utils.collect_env import get_pretty_env_info
from transformers import __version__ as trans_version


@@ -97,6 +96,8 @@ def get_git_commit_hash():

def add_env_info(storage: Dict[str, Any]):
try:
from torch.utils.collect_env import get_pretty_env_info

pretty_env_info = get_pretty_env_info()
except Exception as err:
pretty_env_info = str(err)
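
The logger now defers the torch import until environment info is actually requested. A self-contained sketch of that deferred-import fallback (the helper name is hypothetical, mirroring add_env_info):

def collect_env_summary() -> str:
    # Import torch's env collector lazily; any failure, including a missing
    # torch, degrades to an error string instead of breaking module import.
    try:
        from torch.utils.collect_env import get_pretty_env_info

        return get_pretty_env_info()
    except Exception as err:
        return str(err)
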
34 changes: 27 additions & 7 deletions lm_eval/models/__init__.py
@@ -2,19 +2,39 @@
anthropic_llms,
dummy,
gguf,
huggingface,
mamba_lm,
nemo_lm,
neuralmagic,
neuron_optimum,
openai_completions,
optimum_lm,
textsynth,
vllm_causallms,
)


# TODO: implement __all__
Author comment: drive-by fix since I needed to do some mucking around in here anyway.

__all__ = [
"anthropic_llms",
"dummy",
"gguf",
"openai_completions",
"textsynth",
"vllm_causallms",
]


# try importing all modules that need torch
import importlib


for module_that_needs_torch in [
"huggingface",
"mamba_lm",
"nemo_lm",
"neuralmagic",
"neuron_optimum",
"optimum_lm",
]:
try:
importlib.import_module(f".{module_that_needs_torch}", __name__)
__all__.append(module_that_needs_torch)
except ImportError:
pass


try:
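
Because torch-only backends are now registered best-effort, `lm_eval.models.__all__` doubles as a runtime record of which backends actually imported. A hypothetical check a downstream script could run:

import lm_eval.models as models

# Backends whose import failed (the torch-only ones in a minimal install)
# are simply never appended to __all__.
print(sorted(models.__all__))
if "huggingface" not in models.__all__:
    print("torch-only backends unavailable; install lm_eval[hf] to enable them")
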
18 changes: 12 additions & 6 deletions lm_eval/models/utils.py
@@ -18,7 +18,12 @@
Union,
)

import torch

try:
import torch
except ImportError:
torch = None

import transformers

from lm_eval.utils import eval_logger
@@ -141,7 +146,7 @@ def get_original(self, grouped_dict):

def pad_and_concat(
max_length: int,
tensors: List[torch.Tensor],
tensors: List["torch.Tensor"],
padding_side: Literal["right", "left"] = "right",
):
"""
@@ -192,10 +197,11 @@

def clear_torch_cache() -> None:
gc.collect()
torch.cuda.empty_cache()
if torch is not None:
torch.cuda.empty_cache()


def get_dtype(dtype: Union[str, torch.dtype]) -> torch.dtype:
def get_dtype(dtype: Union[str, "torch.dtype"]) -> "torch.dtype":
"""Converts `dtype` from `str` to torch.dtype when possible. Does not use an instantiated HF AutoConfig"""
if isinstance(dtype, str) and dtype != "auto":
# Convert `str` args torch dtype: `float16` -> `torch.float16`
@@ -435,8 +441,8 @@ def get_cache(
req_str: Tuple[str, str] = None,
cxt_toks: List[int] = None,
cont_toks: List[int] = None,
logits: torch.Tensor = None,
) -> Iterator[Tuple[Tuple[str, str], List[int], torch.Tensor]]:
logits: "torch.Tensor" = None,
) -> Iterator[Tuple[Tuple[str, str], List[int], "torch.Tensor"]]:
"""
Retrieves cached single-token continuations and their associated arguments, updating indices as necessary.

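
The quoted annotations are what keep this module importable without torch: string annotations are never evaluated at definition time, and the `torch = None` sentinel gives runtime code something to branch on. A standalone sketch of the same idea (the function is illustrative, not from the module):

try:
    import torch
except ImportError:
    torch = None  # sentinel: quoted annotations stay unevaluated, calls are guarded


def scale(x: "torch.Tensor", factor: float) -> "torch.Tensor":
    # Defining this function is safe without torch; only calling it needs it.
    if torch is None:
        raise ImportError("torch is required for tensor operations")
    return x * factor
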
10 changes: 6 additions & 4 deletions pyproject.toml
@@ -19,26 +19,24 @@ classifiers = [
requires-python = ">=3.8"
license = { "text" = "MIT" }
dependencies = [
"accelerate>=0.26.0",
"evaluate",
"datasets>=2.16.0",
"evaluate>=0.4.0",
"jsonlines",
"numexpr",
"peft>=0.2.0",
"pybind11>=2.6.2",
"pytablewriter",
"rouge-score>=0.0.4",
"sacrebleu>=1.5.0",
"scikit-learn>=0.24.1",
"sqlitedict",
"torch>=1.8",
"tqdm-multiprocess",
"transformers>=4.1",
"zstandard",
"dill",
"word2number",
"more_itertools",
"jinja2>=3.0.0",
]

[tool.setuptools.packages.find]
@@ -57,8 +55,10 @@ Homepage = "https://github.com/EleutherAI/lm-evaluation-harness"
Repository = "https://github.com/EleutherAI/lm-evaluation-harness"

[project.optional-dependencies]
torch = ["torch>=1.8"]
hf = ["transformers", "torch>=1.8", "accelerate>=0.26.0", "peft>=0.2.0"]
Author comment: Two new extras. My guess is that people will generally want these, but pip has no way of specifying a "default" install and a "minimal" install.

anthropic = ["anthropic"]
dev = ["pytest", "pytest-cov", "pytest-xdist", "pre-commit", "mypy"]
dev = ["pytest", "pytest-cov", "pytest-xdist", "pre-commit", "mypy", "accelerate>=0.26.0", "peft>=0.2.0"]
deepsparse = ["deepsparse-nightly[llm]>=1.8.0.20240404"]
gptq = ["auto-gptq[triton]>=0.6.0"]
hf_transfer = ["hf_transfer"]
@@ -78,6 +78,8 @@ zeno = ["pandas", "zeno-client"]
wandb = ["wandb>=0.16.3", "pandas", "numpy"]
unitxt = ["unitxt"]
all = [
"lm_eval[torch]",
"lm_eval[hf]",
"lm_eval[anthropic]",
"lm_eval[dev]",
"lm_eval[deepsparse]",
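
With torch dropped from the core dependencies, a minimal install may legitimately lack the whole Hugging Face stack. A hypothetical pre-flight check, with the package list taken from the `hf` extra above:

import importlib.util

HF_EXTRA_PACKAGES = ("torch", "transformers", "accelerate", "peft")


def hf_extra_available() -> bool:
    # find_spec returns None for packages that are not installed.
    return all(importlib.util.find_spec(pkg) is not None for pkg in HF_EXTRA_PACKAGES)


if not hf_extra_available():
    print('Hugging Face backend unavailable; install it with: pip install "lm_eval[hf]"')
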
20 changes: 20 additions & 0 deletions tests/test_evaluator.py
@@ -7,6 +7,7 @@
import lm_eval.api as api
import lm_eval.evaluator as evaluator
from lm_eval import tasks
from lm_eval.models.dummy import DummyLM
from lm_eval.utils import make_table


@@ -15,6 +16,25 @@
# test once we break evaluator into smaller, more manageable pieces


def test_evaluator_with_dummy_lm():
Author comment: ensured this test passes with torch not installed.

task_name = "hellaswag"
limit = 10
lm = DummyLM()

task_manager = tasks.TaskManager()
task_dict = tasks.get_task_dict([task_name], task_manager)

e = evaluator.evaluate(
lm=lm,
task_dict=task_dict,
limit=limit,
bootstrap_iters=0,
)

# mostly just checking it has a pulse
del e


@pytest.mark.parametrize(
"task_name,limit,model,model_args,bootstrap_iters",
[