From 770861806d497b9dd8232184c12bdfbb55c9942a Mon Sep 17 00:00:00 2001
From: David Hall
Date: Thu, 20 Jun 2024 14:05:19 -0700
Subject: [PATCH 1/2] make pytorch an optional dependency

---
 lm_eval/api/registry.py    |  1 +
 lm_eval/evaluator.py       | 19 ++++++++++++++++---
 lm_eval/loggers/utils.py   |  3 ++-
 lm_eval/models/__init__.py | 34 +++++++++++++++++++++++++++-------
 lm_eval/models/utils.py    | 18 ++++++++++++------
 pyproject.toml             | 10 ++++++----
 tests/test_evaluator.py    | 38 +++++++++++++++++++++++++++++++++++++-
 7 files changed, 101 insertions(+), 22 deletions(-)

diff --git a/lm_eval/api/registry.py b/lm_eval/api/registry.py
index 7446a429e6..f15c21f475 100644
--- a/lm_eval/api/registry.py
+++ b/lm_eval/api/registry.py
@@ -37,6 +37,7 @@ def get_model(model_name):
     except KeyError:
         raise ValueError(
             f"Attempted to load model '{model_name}', but no model for this name found! Supported model names: {', '.join(MODEL_REGISTRY.keys())}"
+            "\nIf you are trying to load a model from Hugging Face, please use the `hf` model type and please ensure torch is installed."
         )


diff --git a/lm_eval/evaluator.py b/lm_eval/evaluator.py
index bca48c111d..7c92ed5fae 100644
--- a/lm_eval/evaluator.py
+++ b/lm_eval/evaluator.py
@@ -7,7 +7,6 @@
 from typing import TYPE_CHECKING, List, Optional, Union

 import numpy as np
-import torch

 import lm_eval.api.metrics
 import lm_eval.api.registry
@@ -38,6 +37,14 @@
     from lm_eval.tasks import Task


+try:
+    import torch
+
+    HAS_TORCH = True
+except ImportError:
+    HAS_TORCH = False
+
+
 @positional_deprecated
 def simple_evaluate(
     model,
@@ -145,8 +152,9 @@
         np.random.seed(numpy_random_seed)

     if torch_random_seed is not None:
-        seed_message.append(f"Setting torch manual seed to {torch_random_seed}")
-        torch.manual_seed(torch_random_seed)
+        if HAS_TORCH:
+            seed_message.append(f"Setting torch manual seed to {torch_random_seed}")
+            torch.manual_seed(torch_random_seed)

     if seed_message:
         eval_logger.info(" | ".join(seed_message))
@@ -410,6 +418,8 @@
             requests[reqtype].append(instance)

         if lm.world_size > 1:
+            if not HAS_TORCH:
+                raise ImportError("torch is required for distributed evaluation")
             instances_rnk = torch.tensor(len(task._instances), device=lm.device)
             gathered_item = (
                 lm.accelerator.gather(instances_rnk).cpu().detach().numpy().tolist()
             )
@@ -504,6 +514,9 @@
                     task_output.sample_metrics[(metric, filter_key)].append(value)

     if WORLD_SIZE > 1:
+        if not HAS_TORCH:
+            raise ImportError("torch is required for distributed evaluation")
+
         # if multigpu, then gather data across all ranks to rank 0
         # first gather logged samples across all ranks
         for task_output in eval_tasks:
diff --git a/lm_eval/loggers/utils.py b/lm_eval/loggers/utils.py
index fd47c9ab27..357359a2bd 100644
--- a/lm_eval/loggers/utils.py
+++ b/lm_eval/loggers/utils.py
@@ -6,7 +6,6 @@
 from typing import Any, Dict, Optional, Tuple, Union

 import numpy as np
-from torch.utils.collect_env import get_pretty_env_info
 from transformers import __version__ as trans_version


@@ -97,6 +96,8 @@ def get_git_commit_hash():

 def add_env_info(storage: Dict[str, Any]):
     try:
+        from torch.utils.collect_env import get_pretty_env_info
+
         pretty_env_info = get_pretty_env_info()
     except Exception as err:
         pretty_env_info = str(err)
diff --git a/lm_eval/models/__init__.py b/lm_eval/models/__init__.py
index 698c912f27..db4e0d7ffe 100644
--- a/lm_eval/models/__init__.py
+++ b/lm_eval/models/__init__.py
@@ -2,19 +2,39 @@
     anthropic_llms,
     dummy,
     gguf,
-    huggingface,
-    mamba_lm,
-    nemo_lm,
-    neuralmagic,
-    neuron_optimum,
     openai_completions,
-    optimum_lm,
     textsynth,
     vllm_causallms,
 )


-# TODO: implement __all__
+__all__ = [
+    "anthropic_llms",
+    "dummy",
+    "gguf",
+    "openai_completions",
+    "textsynth",
+    "vllm_causallms",
+]
+
+
+# try importing all modules that need torch
+import importlib
+
+
+for module_that_needs_torch in [
+    "huggingface",
+    "mamba_lm",
+    "nemo_lm",
+    "neuralmagic",
+    "neuron_optimum",
+    "optimum_lm",
+]:
+    try:
+        importlib.import_module(f".{module_that_needs_torch}", __name__)
+        __all__.append(module_that_needs_torch)
+    except ImportError:
+        pass


 try:
diff --git a/lm_eval/models/utils.py b/lm_eval/models/utils.py
index 09818f4edd..1703711325 100644
--- a/lm_eval/models/utils.py
+++ b/lm_eval/models/utils.py
@@ -18,7 +18,12 @@
     Union,
 )

-import torch
+
+try:
+    import torch
+except ImportError:
+    torch = None
+
 import transformers

 from lm_eval.utils import eval_logger
@@ -141,7 +146,7 @@ def get_original(self, grouped_dict):

 def pad_and_concat(
     max_length: int,
-    tensors: List[torch.Tensor],
+    tensors: List["torch.Tensor"],
     padding_side: Literal["right", "left"] = "right",
 ):
     """
@@ -192,10 +197,11 @@ def pad_and_concat(

 def clear_torch_cache() -> None:
     gc.collect()
-    torch.cuda.empty_cache()
+    if torch is not None:
+        torch.cuda.empty_cache()


-def get_dtype(dtype: Union[str, torch.dtype]) -> torch.dtype:
+def get_dtype(dtype: Union[str, "torch.dtype"]) -> "torch.dtype":
     """Converts `dtype` from `str` to torch.dtype when possible. Does not use an instantiated HF AutoConfig"""
     if isinstance(dtype, str) and dtype != "auto":
         # Convert `str` args torch dtype: `float16` -> `torch.float16`
@@ -435,8 +441,8 @@ def get_cache(
         req_str: Tuple[str, str] = None,
         cxt_toks: List[int] = None,
         cont_toks: List[int] = None,
-        logits: torch.Tensor = None,
-    ) -> Iterator[Tuple[Tuple[str, str], List[int], torch.Tensor]]:
+        logits: "torch.Tensor" = None,
+    ) -> Iterator[Tuple[Tuple[str, str], List[int], "torch.Tensor"]]:
         """
         Retrieves cached single-token continuations and their associated arguments,
         updating indices as necessary.
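Not part of the patch: the hunks above all apply the same optional-import pattern — try the torch import once at module load, remember whether it succeeded, and fail only when a torch-only code path is actually exercised, while string type annotations keep the annotations themselves from touching the missing module. A minimal, self-contained sketch of that pattern (module and function names here are illustrative, not taken from the repository):

    # optional_torch_sketch.py -- illustrative only; names are hypothetical.
    from typing import List

    try:
        import torch  # heavy, optional dependency

        HAS_TORCH = True
    except ImportError:
        torch = None
        HAS_TORCH = False


    def to_tensor(values: List[int]) -> "torch.Tensor":
        # The string annotation keeps this module importable without torch;
        # the hard failure is deferred to the first torch-only call.
        if not HAS_TORCH:
            raise ImportError("torch is required for to_tensor; install the `torch` extra")
        return torch.tensor(values)

The patch uses both halves of this: a module-level flag (`HAS_TORCH` in evaluator.py, `torch = None` in models/utils.py) plus quoted annotations such as `List["torch.Tensor"]`.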
diff --git a/pyproject.toml b/pyproject.toml
index 3818a1a80a..25b632b631 100644
--- a/pyproject.toml
+++ b/pyproject.toml
@@ -19,26 +19,24 @@ classifiers = [
 requires-python = ">=3.8"
 license = { "text" = "MIT" }
 dependencies = [
-    "accelerate>=0.26.0",
     "evaluate",
     "datasets>=2.16.0",
     "evaluate>=0.4.0",
     "jsonlines",
     "numexpr",
-    "peft>=0.2.0",
     "pybind11>=2.6.2",
     "pytablewriter",
     "rouge-score>=0.0.4",
     "sacrebleu>=1.5.0",
     "scikit-learn>=0.24.1",
     "sqlitedict",
-    "torch>=1.8",
     "tqdm-multiprocess",
     "transformers>=4.1",
     "zstandard",
     "dill",
     "word2number",
     "more_itertools",
+    "jinja2>=3.0.0",
 ]

 [tool.setuptools.packages.find]
@@ -57,8 +55,10 @@ Homepage = "https://github.com/EleutherAI/lm-evaluation-harness"
 Repository = "https://github.com/EleutherAI/lm-evaluation-harness"

 [project.optional-dependencies]
+torch = ["torch>=1.8"]
+hf = ["transformers", "torch>=1.8", "accelerate>=0.26.0", "peft>=0.2.0"]
 anthropic = ["anthropic"]
-dev = ["pytest", "pytest-cov", "pytest-xdist", "pre-commit", "mypy"]
+dev = ["pytest", "pytest-cov", "pytest-xdist", "pre-commit", "mypy", "accelerate>=0.26.0", "peft>=0.2.0"]
 deepsparse = ["deepsparse-nightly[llm]>=1.8.0.20240404"]
 gptq = ["auto-gptq[triton]>=0.6.0"]
 hf_transfer = ["hf_transfer"]
@@ -78,6 +78,8 @@ zeno = ["pandas", "zeno-client"]
 wandb = ["wandb>=0.16.3", "pandas", "numpy"]
 unitxt = ["unitxt"]
 all = [
+    "lm_eval[torch]",
+    "lm_eval[hf]",
     "lm_eval[anthropic]",
     "lm_eval[dev]",
     "lm_eval[deepsparse]",
diff --git a/tests/test_evaluator.py b/tests/test_evaluator.py
index d5edf9aec2..14764fc3ed 100644
--- a/tests/test_evaluator.py
+++ b/tests/test_evaluator.py
@@ -1,12 +1,14 @@
 import os
 import re
-from typing import List
+from typing import List, Tuple

 import pytest

 import lm_eval.api as api
 import lm_eval.evaluator as evaluator
 from lm_eval import tasks
+from lm_eval.api.model import LM
+from lm_eval.models.dummy import DummyLM
 from lm_eval.utils import make_table


@@ -15,6 +17,40 @@
 # test once we break evaluator into smaller, more manageable pieces


+class FakeLM(LM):
+    def loglikelihood(self, requests) -> List[Tuple[float, bool]]:
+        return [(-100.0, False) for _ in requests]
+
+    def loglikelihood_rolling(self, requests) -> List[Tuple[float]]:
+        return [(-100.0,) for _ in requests]
+
+    def generate_until(self, requests) -> List[str]:
+        output = []
+        for request in requests:
+            output.append("ZZZ" + request.until)
+
+        return output
+
+
+def test_evaluator_with_dummy_lm():
+    task_name = "hellaswag"
+    limit = 10
+    lm = DummyLM()
+
+    task_manager = tasks.TaskManager()
+    task_dict = tasks.get_task_dict([task_name], task_manager)
+
+    e = evaluator.evaluate(
+        lm=lm,
+        task_dict=task_dict,
+        limit=limit,
+        bootstrap_iters=0,
+    )
+
+    # mostly just checking it has a pulse
+    del e
+
+
 @pytest.mark.parametrize(
     "task_name,limit,model,model_args,bootstrap_iters",
     [

From ecd1bd622acdd602215e4d05b7a83563413fa985 Mon Sep 17 00:00:00 2001
From: David Hall
Date: Thu, 20 Jun 2024 14:10:08 -0700
Subject: [PATCH 2/2] remove FakeLM

---
 tests/test_evaluator.py | 18 +-----------------
 1 file changed, 1 insertion(+), 17 deletions(-)

diff --git a/tests/test_evaluator.py b/tests/test_evaluator.py
index 14764fc3ed..c6bf77ebcd 100644
--- a/tests/test_evaluator.py
+++ b/tests/test_evaluator.py
@@ -1,13 +1,12 @@
 import os
 import re
-from typing import List, Tuple
+from typing import List

 import pytest

 import lm_eval.api as api
 import lm_eval.evaluator as evaluator
 from lm_eval import tasks
-from lm_eval.api.model import LM
 from lm_eval.models.dummy import DummyLM
 from lm_eval.utils import make_table

@@ -17,21 +16,6 @@
 # test once we break evaluator into smaller, more manageable pieces


-class FakeLM(LM):
-    def loglikelihood(self, requests) -> List[Tuple[float, bool]]:
-        return [(-100.0, False) for _ in requests]
-
-    def loglikelihood_rolling(self, requests) -> List[Tuple[float]]:
-        return [(-100.0,) for _ in requests]
-
-    def generate_until(self, requests) -> List[str]:
-        output = []
-        for request in requests:
-            output.append("ZZZ" + request.until)
-
-        return output
-
-
 def test_evaluator_with_dummy_lm():
     task_name = "hellaswag"
     limit = 10
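Not part of the patches: with the conditional imports added in lm_eval/models/__init__.py, `__all__` lists only the backends whose imports succeeded, so a quick runtime check shows whether the torch-backed backends are available. A small sketch, assuming lm_eval is installed with these patches applied (the new `torch` or `hf` extra pulls torch in):

    # check_backends.py -- illustrative only.
    import lm_eval.models

    # These torch-free backends are registered unconditionally by the patched __init__.py.
    print([m for m in ("gguf", "openai_completions", "textsynth") if m in lm_eval.models.__all__])

    # Torch-backed backends are appended to __all__ only if their import succeeded.
    if "huggingface" in lm_eval.models.__all__:
        print("torch found; `hf` models can be loaded")
    else:
        print("torch not installed; only torch-free backends are available")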