Update Unitxt task to use locally installed unitxt and not download Unitxt code from Huggingface (#2514)

* Moved to requiring a local unitxt installation rather than downloading the unitxt code from the HF hub.

This has performance benefits and simplifies the code.

Signed-off-by: Yoav Katz <[email protected]>

* Updated watsonx documentation

* Updated installation instructions

* Removed redundant comma

* Allowed unitxt tasks to generate chat API requests

Modified the watsonx.ai model to support chat APIs

* Removed print

* Ran pre-commit formatting

---------

Signed-off-by: Yoav Katz <[email protected]>
yoavkatz authored Dec 1, 2024
1 parent 0230356 commit 1170ef9
Showing 4 changed files with 58 additions and 15 deletions.
3 changes: 3 additions & 0 deletions README.md
@@ -260,6 +260,7 @@ Note that for externally hosted models, configs such as `--device` which relate
 | Neuron via AWS Inf2 (Causal LMs) | ✔️ | `neuronx` | Any decoder-only AutoModelForCausalLM supported to run on [huggingface-ami image for inferentia2](https://aws.amazon.com/marketplace/pp/prodview-gr3e6yiscria2) | `generate_until`, `loglikelihood`, `loglikelihood_rolling` | ... |
 | [Neural Magic DeepSparse](https://github.com/neuralmagic/deepsparse) | ✔️ | `deepsparse` | Any LM from [SparseZoo](https://sparsezoo.neuralmagic.com/) or on [HF Hub with the "deepsparse" tag](https://huggingface.co/models?other=deepsparse) | `generate_until`, `loglikelihood` | ... |
 | [Neural Magic SparseML](https://github.com/neuralmagic/sparseml) | ✔️ | `sparseml` | Any decoder-only AutoModelForCausalLM from [SparseZoo](https://sparsezoo.neuralmagic.com/) or on [HF Hub](https://huggingface.co/neuralmagic). Especially useful for models with quantization like [`zoo:llama2-7b-gsm8k_llama2_pretrain-pruned60_quantized`](https://sparsezoo.neuralmagic.com/models/llama2-7b-gsm8k_llama2_pretrain-pruned60_quantized) | `generate_until`, `loglikelihood`, `loglikelihood_rolling` | ... |
+| Watsonx.ai | ✔️ | `watsonx_llm` | [Supported Watsonx.ai Engines](https://dataplatform.cloud.ibm.com/docs/content/wsj/analyze-data/fm-models.html?context=wx) | `generate_until`, `loglikelihood` | ... |
 | Your local inference server! | ✔️ | `local-completions` or `local-chat-completions` | Support for OpenAI API-compatible servers, with easy customization for other APIs. | `generate_until`, `loglikelihood`, `loglikelihood_rolling` | ... |
 
 Models which do not supply logits or logprobs can be used with tasks of type `generate_until` only, while local models, or APIs that supply logprobs/logits of their prompts, can be run on all task types: `generate_until`, `loglikelihood`, `loglikelihood_rolling`, and `multiple_choice`.
@@ -476,6 +477,8 @@ Extras dependencies can be installed via `pip install -e ".[NAME]"`
 | gptq | For loading models with GPTQ |
 | hf_transfer | For speeding up HF Hub file downloads |
 | ifeval | For running the IFEval task |
+| ibm_watsonx_ai | For using IBM watsonx.ai model APIs |
+
 | neuronx | For running on AWS inf2 instances |
 | mamba | For loading Mamba SSM models |
 | math | For running math task answer checking |
19 changes: 18 additions & 1 deletion lm_eval/models/ibm_watsonx_ai.py
@@ -1,4 +1,5 @@
 import copy
+import json
 import os
 from functools import lru_cache
 from typing import Any, Dict, List, NamedTuple, Optional, Tuple, Type, cast
@@ -8,6 +9,7 @@
 from lm_eval.api.instance import Instance
 from lm_eval.api.model import LM
 from lm_eval.api.registry import register_model
+from lm_eval.models.api_models import JsonChatStr
 from lm_eval.utils import eval_logger, simple_parse_args_string
 
 
@@ -248,7 +250,12 @@ def generate_until(self, requests: List[Instance]) -> List[str]:
         ):
             context, continuation = request
             try:
-                response = self.model.generate_text(context, self.generate_params)
+                if isinstance(context, JsonChatStr):
+                    context = json.loads(context.prompt)
+                    response = self.model.chat(context, self.generate_params)
+                    response = response["choices"][0]["message"]["content"]
+                else:
+                    response = self.model.generate_text(context, self.generate_params)
             except Exception as exp:
                 eval_logger.error("Error while generating text.")
                 raise exp
@@ -372,3 +379,13 @@ def loglikelihood_rolling(self, requests) -> List[Tuple[float, bool]]:
         )
 
         return cast(List[Tuple[float, bool]], results)
+
+    @property
+    def tokenizer_name(self) -> str:
+        return ""
+
+    def apply_chat_template(
+        self, chat_history: List[Dict[str, str]]
+    ) -> JsonChatStr:
+        # A hack similar to the one in api_models, to allow encoding the chat history for the cache
+        return JsonChatStr(json.dumps(chat_history))
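
For orientation, the round trip these hunks implement can be sketched in isolation. This is a hedged reconstruction: the real `JsonChatStr` lives in `lm_eval.models.api_models` (this diff only imports it) and is assumed here to be a thin `NamedTuple` wrapper over the serialized chat history.

```python
import json
from typing import Dict, List, NamedTuple


class JsonChatStr(NamedTuple):
    # Stand-in for lm_eval.models.api_models.JsonChatStr (assumption:
    # a thin wrapper whose `prompt` field holds the JSON-encoded chat).
    prompt: str


def apply_chat_template(chat_history: List[Dict[str, str]]) -> JsonChatStr:
    # Serialize the chat history to a plain string so it stays hashable
    # and cacheable, as the new method above does.
    return JsonChatStr(json.dumps(chat_history))


# generate_until's side of the hand-off: detect the wrapper, decode it,
# and pass the recovered message list to the chat endpoint.
encoded = apply_chat_template([{"role": "user", "content": "What is 2+2?"}])
if isinstance(encoded, JsonChatStr):
    messages = json.loads(encoded.prompt)
    print(messages)  # [{'role': 'user', 'content': 'What is 2+2?'}]
```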
2 changes: 2 additions & 0 deletions lm_eval/tasks/unitxt/README.md
@@ -6,6 +6,8 @@ The full Unitxt catalog can be viewed in an [online explorer](https://unitxt.rea
 
 Read more about Unitxt at [www.unitxt.ai](https://www.unitxt.ai/).
 
+To use Unitxt datasets with lm-eval, first install unitxt via `pip install unitxt`.
+
 ### Paper
 
 Title: `Unitxt: Flexible, Shareable and Reusable Data Preparation and Evaluation for Generative AI`
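
As a usage note, loading a dataset through a locally installed unitxt looks roughly like the sketch below, mirroring the updated `download()` in `task.py` further down. The card/template recipe string is a hypothetical example, not something this commit pins down; `disable_cache=False` follows the diff.

```python
# Hedged sketch; requires `pip install unitxt`. The recipe string is a
# hypothetical example drawn from the unitxt catalog.
from unitxt import load_dataset

dataset = load_dataset(
    "card=cards.wnli,template=templates.classification.multi_class.relation.default",
    disable_cache=False,
)
print(dataset["test"][0])
```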
49 changes: 35 additions & 14 deletions lm_eval/tasks/unitxt/task.py
@@ -6,11 +6,11 @@
 
 import importlib.util
 import re
+from collections.abc import Callable
 from functools import partial
 from typing import Any, Dict, Optional
 
 import datasets
-import evaluate
 
 from lm_eval.api.instance import Instance
 from lm_eval.api.task import ConfigurableTask
@@ -28,16 +28,21 @@
 """
 
 
-def is_unitxt_installed() -> bool:
-    return importlib.util.find_spec("unitxt") is not None
+def assert_unitxt_installed():
+    if importlib.util.find_spec("unitxt") is None:
+        raise Exception(
+            "Please install unitxt via 'pip install unitxt'. For more information see: https://www.unitxt.ai/"
+        )
 
 
 def score(items, metric):
     predictions, references = zip(*items)
-    evaluator = evaluate.load("unitxt/metric")
+    assert_unitxt_installed()
+    from unitxt import evaluate
+
     for reference in references:
         reference["metrics"] = [metric]
-    results = evaluator.compute(predictions=predictions, references=references)
+    results = evaluate(predictions, references)
     return results[0]["score"]["global"]["score"]
 
 
@@ -61,16 +66,10 @@ def __init__(
         self.metrics = self.dataset["test"][0]["metrics"]
 
     def download(self, dataset_kwargs: Optional[Dict[str, Any]] = None) -> None:
-        if is_unitxt_installed():
-            from unitxt import load_dataset
+        assert_unitxt_installed()
+        from unitxt import load_dataset
 
-            self.dataset = load_dataset(self.DATASET_NAME)
-        else:
-            self.dataset = datasets.load_dataset(
-                name=self.DATASET_NAME,
-                path="unitxt/data",
-                trust_remote_code=True,
-            )
+        self.dataset = load_dataset(self.DATASET_NAME, disable_cache=False)
 
     def has_training_docs(self):
         return "train" in self.dataset
@@ -102,6 +101,27 @@ def doc_to_target(self, doc):
     def get_arguments(self, doc, ctx):
         return (ctx, {"until": ["\n"]})
 
+    def fewshot_context(
+        self,
+        doc: str,
+        num_fewshot: int,
+        system_instruction: Optional[str] = None,
+        apply_chat_template: bool = False,
+        fewshot_as_multiturn: bool = False,
+        chat_template: Optional[Callable] = None,
+    ) -> str:
+        source = self.doc_to_text(doc)
+        if isinstance(source, list):
+            if apply_chat_template:
+                formatted_source = chat_template(self.doc_to_text(doc))
+                return formatted_source
+            else:
+                raise Exception(
+                    "Got chat template format from Unitxt, but apply_chat_template is false. Add '--apply_chat_template' to the command line."
+                )
+        else:
+            return source
+
     def construct_requests(self, doc, ctx, **kwargs):
         """Uses RequestFactory to construct Requests and returns an iterable of
         Requests which will be sent to the LM.
@@ -113,6 +133,7 @@ def construct_requests(self, doc, ctx, **kwargs):
         language description, as well as the few shot examples, and the question
         part of the document for `doc`.
         """
+        kwargs.pop("apply_chat_template", False)  # Not used by unitxt
         return [
             Instance(
                 request_type="generate_until",
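
Taken together, a run that exercises the new chat path would look roughly like this. `simple_evaluate` and its `apply_chat_template` argument are part of lm-eval's existing Python API rather than this diff, and the model and task identifiers below are placeholders:

```python
from lm_eval import simple_evaluate

# Hypothetical invocation; watsonx.ai credentials are assumed to be set in
# the environment, and the identifiers below are placeholders.
results = simple_evaluate(
    model="watsonx_llm",                    # registered by ibm_watsonx_ai.py
    model_args="model_id=<your-model-id>",  # placeholder watsonx.ai model id
    tasks=["<your-unitxt-task>"],           # placeholder unitxt-backed task
    apply_chat_template=True,               # needed when Unitxt emits chat format
)
print(results["results"])
```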
