add GenAI_HFLM class to support microservice. #16

Merged
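This PR adds a GenAI_HFLM model class (registered as "genai-hf") so lm-evaluation-harness can score loglikelihood requests against a remote GenAI microservice over HTTP instead of a locally hosted model. It also points the CPU/HPU model-test workflows at their matching workflow files and threads proxy settings and runner environment variables through the unit-test Docker build and run steps.

A minimal usage sketch, illustrative only: the endpoint URL, tokenizer, and task below are assumptions, and the "genai-hf" name is only available once the module defining GenAI_HFLM has been imported so the registry entry exists.

# Illustrative only: evaluate a task through a GenAI microservice endpoint.
import lm_eval

results = lm_eval.simple_evaluate(
    model="genai-hf",  # model type registered by this PR
    model_args="base_url=http://localhost:8080,tokenizer=EleutherAI/gpt-j-6b",
    tasks=["lambada_openai"],
    batch_size=1,
)
print(results["results"])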
2 changes: 1 addition & 1 deletion .github/workflows/model_test_cpu.yml
@@ -131,7 +131,7 @@ jobs:
id: download-artifact
uses: dawidd6/[email protected]
with:
workflow: model-test.yml
workflow: model_test_cpu.yml
name: FinalReport
run_id: ${{ vars.ModelTest_CPU_REF_ID }}
path: ${{ env.OUT_SCRIPT_PATH }}
2 changes: 1 addition & 1 deletion .github/workflows/model_test_hpu.yml
@@ -119,7 +119,7 @@ jobs:
id: download-artifact
uses: dawidd6/[email protected]
with:
workflow: model-test.yml
workflow: model_test_hpu.yml
name: FinalReport
run_id: ${{ vars.ModelTest_HPU_REF_ID }}
path: ${{ env.OUT_SCRIPT_PATH }}
7 changes: 5 additions & 2 deletions .github/workflows/unittest.yml
@@ -52,7 +52,9 @@ jobs:
steps:
- name: Clean Up Working Directory
run: sudo rm -rf ${{github.workspace}}/*

- name: Load environment variables
run:
cat ~/actions-runner4/.env >> $GITHUB_ENV
- name: Checkout out Repo
uses: actions/checkout@v4
with:
@@ -62,7 +64,7 @@

- name: Docker Build
run: |
docker build -f ${{ github.workspace }}/.github/workflows/docker/common.dockerfile -t ${{ env.DOCKER_NAME }}:${{ env.DOCKER_TAG }} .
docker build --build-arg http_proxy="${{ env.HTTP_PROXY_IMAGE_BUILD }}" --build-arg https_proxy="${{ env.HTTPS_PROXY_IMAGE_BUILD }}" -f ${{ github.workspace }}/.github/workflows/docker/common.dockerfile -t ${{ env.DOCKER_NAME }}:${{ env.DOCKER_TAG }} .

- name: Docker Run
run: |
@@ -71,6 +73,7 @@
docker rm -vf ${{ env.CONTAINER_NAME }} || true
fi
docker run -dit --memory="4g" --memory-reservation="1g" --disable-content-trust --privileged --name=${{ env.CONTAINER_NAME }} --shm-size="1g" \
-e http_proxy="${{ env.HTTP_PROXY_CONTAINER_RUN }}" -e https_proxy="${{ env.HTTPS_PROXY_CONTAINER_RUN }}" \
-v ${{ github.workspace }}:/GenAIEval ${{ env.DOCKER_NAME }}:${{ env.DOCKER_TAG }}

- name: Install Dependencies
@@ -15,22 +15,26 @@
# See the License for the specific language governing permissions and
# limitations under the License.
import copy
import json
import os
from datetime import timedelta
from pathlib import Path
from typing import List, Literal, Optional, Tuple, Union

import requests as requests_obj
import torch
import torch.nn.functional as F
import transformers
from accelerate import Accelerator, DistributedType, InitProcessGroupKwargs, find_executable_batch_size
from lm_eval import utils
from lm_eval.api.instance import Instance
from lm_eval.api.model import TemplateLM
from lm_eval.api.model import CacheHook, TemplateLM
from lm_eval.api.registry import register_model
from lm_eval.models.utils import Collator, clear_torch_cache, get_dtype, pad_and_concat, stop_sequences_criteria
from packaging import version
from peft import PeftModel
from peft import __version__ as PEFT_VERSION
from requests.exceptions import RequestException
from tqdm import tqdm
from transformers.models.auto.modeling_auto import (
MODEL_FOR_CAUSAL_LM_MAPPING_NAMES,
@@ -1217,3 +1221,282 @@ def _model_call(self, inps):
logits = logits[:, :-padding_length, :]
logits = logits.to(torch.float32)
return logits


@register_model("genai-hf")
class GenAI_HFLM(HFLM):
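"""Wrapper that scores loglikelihood requests through a remote GenAI microservice.

Tokenization and request collation happen locally with a Hugging Face tokenizer;
the right-padded token batches are POSTed to `{base_url}/v1/completions`, and the
service's per-sequence greedy tokens and log-probabilities are sliced to the
continuation span to produce (log prob, is-exact-match) answers.
"""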
AUTO_MODEL_CLASS = transformers.AutoModelForCausalLM

def __init__(
self,
base_url=None,
logits_cache: bool = True,
tokenizer: Optional[str] = None,
revision: Optional[str] = "main",
batch_size: int = 1,
max_length: Optional[int] = None,
trust_remote_code: Optional[bool] = False,
use_fast_tokenizer: Optional[bool] = True,
add_bos_token: Optional[bool] = False,
prefix_token_id: Optional[int] = None,
**kwargs,
):
self.base_url = base_url
assert self.base_url, "must pass `base_url` to use the GenAI service!"
self._rank = 0
self._world_size = 1

self.tokenizer = transformers.AutoTokenizer.from_pretrained(
tokenizer,
revision=revision,
trust_remote_code=trust_remote_code,
use_fast=use_fast_tokenizer,
)

self.logits_cache = logits_cache
# select (or create) a pad token to use
if self.tokenizer.pad_token:
pass
elif self.tokenizer.unk_token:
self.tokenizer.pad_token_id = self.tokenizer.unk_token_id
elif self.tokenizer.eos_token:
self.tokenizer.pad_token_id = self.tokenizer.eos_token_id
else:
if getattr(self.config, "model_type", None) == "qwen":
# Qwen's trust_remote_code tokenizer does not allow for adding special tokens
self.tokenizer.pad_token = "<|endoftext|>"
elif (
self.tokenizer.__class__.__name__ == "RWKVWorldTokenizer"
or self.tokenizer.__class__.__name__ == "Rwkv5Tokenizer"
):
# The RWKV world tokenizer does not allow adding special tokens / setting the pad token (which is set as 0)
# The additional tokenizer name check is needed, as there exist rwkv4 models with a neox tokenizer
# ---
# Note that the world tokenizer class name might change in the future for the final huggingface merge
# https://github.com/huggingface/transformers/pull/26963
assert self.tokenizer.pad_token_id == 0
else:
self.tokenizer.add_special_tokens({"pad_token": "<|pad|>"})

# TODO: override this for Gemma
self.add_bos_token = add_bos_token
if "GemmaTokenizer" in self.tokenizer.__class__.__name__:
self.add_bos_token = True
eval_logger.info(
f"Model type is '{self.config.model_type}', a BOS token will be used as Gemma underperforms without it."
)

self._batch_size = int(batch_size)
self._max_length = max_length
self.custom_prefix_token_id = prefix_token_id
if prefix_token_id is not None:
eval_logger.info(f"Loglikelihood prefix token id used in evaluation: {self.prefix_token_id}")
self.cache_hook = CacheHook(None)
self.headers = {"Content-Type": "application/json"}

@property
def max_length(self) -> int:
if self._max_length:
return self._max_length
else:
return self._DEFAULT_MAX_LENGTH

@property
def batch_size(self) -> int:
return self._batch_size

def _loglikelihood_tokens(
self,
task_requests: List[Tuple[Tuple[str, str], List[int], List[int]]],
disable_tqdm: bool = False,
override_bs: int = None,
) -> List[Tuple[float, bool]]:
# TODO: implement some kind of efficient-request-middleware that lumps together requests with the same context
res = []

def _collate(req: Tuple[Tuple[str, str], List[int], List[int]]):
"""Defines the key for the sorted method."""
# the negative sign on len(toks) sorts descending - this has a few advantages:
# - time estimates will always be over not underestimates, which is more useful for planning
# - to know the size of a batch when going through the list, you know the first one is always the batch
# padded context length. this is useful to simplify the batching logic and more importantly to make
# automatic adaptive batches much much easier to implement
# - any OOMs will happen right away rather than near the end

toks = req[1] + req[2]
return -len(toks), tuple(toks)

def _lookup_one_token_cont(req: Tuple[Tuple[str, str], List[int], List[int]]):
"""Defines the key to group and lookup one-token continuations."""
# Use with group_by="contexts" (optional):
# allows the creation of a lookup, so we can reuse logits in case of one-token continuations.
# speeds up some multiple-choice tasks proportionally to the number of choices.
# groups requests by context+continuation[:-1] and infers on one request per group.
return req[-2] + req[-1][:-1]

re_ord = Collator(
task_requests,
sort_fn=_collate,
group_by=None,
group_fn=_lookup_one_token_cont,
)

# automatic (variable) batch size detection for vectorization
# pull longest context sample from request
n_reordered_requests = len(re_ord)
batch_size = self.batch_size if self.batch_size != "auto" else override_bs if override_bs is not None else 0
batch_fn = (
self._batch_scheduler
if self.batch_size == "auto" and n_reordered_requests > 0 and not override_bs
else None
)

chunks = re_ord.get_batched(n=batch_size, batch_fn=batch_fn)
pbar = tqdm(
total=len(task_requests),
disable=(disable_tqdm or (self.rank != 0)),
desc="Running loglikelihood requests",
)
for chunk in chunks:
inps = []
cont_toks_list = []
inplens = []

conts = []
encoder_attns = []

padding_len_inp = None
padding_len_cont = None
# because vectorizing is annoying, we first convert each (context, continuation) pair to padded
# tensors, then we pack them together into a batch, call the model, and then pick it all apart
# again because vectorizing is annoying

for _, context_enc, continuation_enc in chunk:
# sanity check
assert len(context_enc) > 0
assert len(continuation_enc) > 0
assert len(continuation_enc) <= self.max_length

# how this all works (illustrated on a causal decoder-only setup):
# CTX CONT
# inp 0 1 2 3|4 5 6 7 8 9 <- last token is deleted by inp[:, :-1]
# model \ \
# logits 1 2 3|4 5 6 7 8 9 <- the ctx half gets tossed out by the
# cont_toks 4 5 6 7 8 9 [:, -len(continuation_enc):, :self.vocab_size] slice

# when too long to fit in context, truncate from the left
if self.AUTO_MODEL_CLASS == transformers.AutoModelForCausalLM:
inp = torch.tensor(
(context_enc + continuation_enc)[-(self.max_length + 1) :],
dtype=torch.long,
)
(inplen,) = inp.shape
elif self.AUTO_MODEL_CLASS == transformers.AutoModelForSeq2SeqLM:
inp = torch.tensor(
(context_enc)[-self.max_length :],
dtype=torch.long,
)
(inplen,) = inp.shape

# build encoder attn masks
encoder_attns.append(torch.ones_like(inp))

cont = torch.tensor(
(continuation_enc)[-self.max_length :],
# TODO: left-shift these?
# TODO: our code assumes we never end up truncating conts for either model type
dtype=torch.long,
)
(contlen,) = cont.shape

conts.append(cont)

padding_len_cont = max(padding_len_cont, contlen) if padding_len_cont is not None else contlen

padding_len_inp = max(padding_len_inp, inplen) if padding_len_inp is not None else inplen

inps.append(inp) # [1, inp_length]
cont_toks_list.append(continuation_enc)
inplens.append(inplen)

# create encoder attn mask and batched conts, if seq2seq
call_kwargs = {}
if self.AUTO_MODEL_CLASS == transformers.AutoModelForCausalLM:
batched_inps = pad_and_concat(padding_len_inp, inps, padding_side="right") # [batch, padding_len_inp]
elif self.AUTO_MODEL_CLASS == transformers.AutoModelForSeq2SeqLM:
# TODO: left-pad encoder inps and mask?
batched_inps = pad_and_concat(padding_len_inp, inps) # [batch, padding_len_inp]
batched_conts = pad_and_concat(padding_len_cont, conts) # [batch, padding_len_cont]
batched_encoder_mask = pad_and_concat(padding_len_inp, encoder_attns) # [batch, padding_len_inp]
call_kwargs = {
"attn_mask": batched_encoder_mask,
"labels": batched_conts,
}

data = {
"batched_inputs": batched_inps.tolist(),
}
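# The service is expected to return, for every sequence in the batch, the greedy
# token ids and per-position log-probabilities ("greedy_tokens" and "logprobs");
# they are sliced down to the continuation span below.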
try:
response = requests_obj.post(
f"{self.base_url}/v1/completions",
headers=self.headers,
data=json.dumps(data),
)
response.raise_for_status()
response = response.json()
except RequestException as e:
eval_logger.error(f"RequestException: {e}")
raise  # re-raise: the response is consumed below, so a failed request must not be silently skipped

for (request_str, ctx_tokens, _), greedy_tokens, logprobs, inplen, cont_toks in zip(
chunk, response["greedy_tokens"], response["logprobs"], inplens, cont_toks_list
):
# Slice to original seq length
contlen = len(cont_toks)
# take only logits in the continuation
# (discard context toks if decoder-only ; discard right-padding)
# also discards + checks for "virtual tokens" in the causal LM's input window
# from prompt/prefix tuning tokens, if applicable
ctx_len = (
inplen + (len(logprobs) - padding_len_inp)
if self.AUTO_MODEL_CLASS == transformers.AutoModelForCausalLM
else None
)
cont_toks = torch.tensor(cont_toks, dtype=torch.long).unsqueeze(0) # [1, seq]
greedy_tokens = torch.tensor(
self._select_cont_toks(greedy_tokens, contlen=contlen, inplen=ctx_len), dtype=torch.long
).unsqueeze(
0
) # [1, seq]
max_equal = (greedy_tokens == cont_toks).all()
cont_logprobs = self._select_cont_toks(logprobs, contlen=contlen, inplen=ctx_len)

# Answer: (log prob, is-exact-match)
answer = (sum(cont_logprobs), bool(max_equal))

res.append(answer)

self.cache_hook.add_partial("loglikelihood", request_str, answer)
pbar.update(1)

pbar.close()

return re_ord.get_original(res)

def _model_call(self, inps):
# Isn't used because we override _loglikelihood_tokens
raise NotImplementedError()

def _model_generate(self, context, max_length, eos_token_id):
# Isn't used because we override generate_until
raise NotImplementedError()

@property
def device(self):
# Isn't used because we override _loglikelihood_tokens
raise NotImplementedError()

def loglikelihood_rolling(self, requests, disable_tqdm: bool = False):
raise NotImplementedError("loglikelihood_rolling not yet supported for GenAI service")

def generate_until(self, requests, disable_tqdm: bool = False) -> List[str]:
raise NotImplementedError("Not supported yet.")
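
For reference, the HTTP contract the new class consumes can be exercised against a minimal stand-in service. The sketch below is illustrative only and not part of this PR: FastAPI, uvicorn, the gpt2 checkpoint, and the one-entry-per-next-token response alignment are assumptions based on my reading of how `_loglikelihood_tokens` derives `ctx_len` from the length of the returned `logprobs` list and slices with `_select_cont_toks`.

# Illustrative mock of the completion endpoint GenAI_HFLM posts to (an assumption, not the PR's service).
from typing import List

import torch
import torch.nn.functional as F
import transformers
from fastapi import FastAPI
from pydantic import BaseModel

app = FastAPI()
model = transformers.AutoModelForCausalLM.from_pretrained("gpt2").eval()


class CompletionRequest(BaseModel):
    batched_inputs: List[List[int]]  # right-padded token id matrix sent by the client


@app.post("/v1/completions")
def completions(req: CompletionRequest):
    inps = torch.tensor(req.batched_inputs, dtype=torch.long)  # [batch, seq_len]
    with torch.no_grad():
        logits = model(inps).logits  # [batch, seq_len, vocab]
    logprobs = F.log_softmax(logits, dim=-1)
    # One entry per next-token prediction (length seq_len - 1):
    # greedy_tokens[i] is the argmax prediction for position i + 1, and
    # logprobs[i] is the log-probability assigned to the token that actually
    # appears at position i + 1.
    greedy_tokens = logprobs[:, :-1, :].argmax(dim=-1)
    target_logprobs = torch.gather(
        logprobs[:, :-1, :], 2, inps[:, 1:].unsqueeze(-1)
    ).squeeze(-1)
    return {
        "greedy_tokens": greedy_tokens.tolist(),
        "logprobs": target_logprobs.tolist(),
    }

# Run with e.g.:  uvicorn mock_completions:app --port 8080  (module name is illustrative)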