Fix for e2e tests (#927)
* fix

* fix grammar

* comments

* remove args, use default

* reset session - avoid running finalize more than once on one session

* better comment
horheynm authored Nov 22, 2024
1 parent b61d4e5 commit 19027b2
Showing 3 changed files with 37 additions and 16 deletions.
10 changes: 8 additions & 2 deletions src/llmcompressor/core/session.py
@@ -1,6 +1,8 @@
from dataclasses import dataclass
from typing import Any, Callable, Dict, List, Optional, Union

from loguru import logger

from llmcompressor.core.events import EventType
from llmcompressor.core.helpers import log_model_info, should_log_model_info
from llmcompressor.core.lifecycle import CompressionLifecycle
@@ -260,12 +262,16 @@ def reset_stage(self):
self.lifecycle.initialized_ = False
self.lifecycle.finalized = False

def get_serialized_recipe(self) -> str:
def get_serialized_recipe(self) -> Optional[str]:
"""
:return: serialized string of the current compiled recipe
"""
recipe = self.lifecycle.recipe_container.compiled_recipe
return recipe.yaml()

if recipe is not None and hasattr(recipe, "yaml"):
return recipe.yaml()

logger.warning("Recipe not found in session - it may have been reset")

def _log_model_info(self):
# Log model level logs if cadence reached
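With the session.py change above, get_serialized_recipe returns Optional[str]: recipe.yaml() is only called when a compiled recipe is actually present and has a yaml() method; otherwise a warning is logged and the function falls through to an implicit None. A minimal sketch of the new behavior from the caller's side, assuming a fresh or reset session has no compiled recipe (the script itself is illustrative, not part of this commit):

    from llmcompressor.core import active_session

    session = active_session()

    # Before a recipe has been compiled (or after session.reset()), the call
    # now logs a warning and returns None instead of raising AttributeError
    # on recipe.yaml().
    recipe_yaml = session.get_serialized_recipe()
    print(recipe_yaml)  # expected: None on a fresh or reset session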
@@ -191,9 +191,10 @@ def skip(*args, **kwargs):

recipe_path = os.path.join(save_directory, "recipe.yaml")
session = active_session()
recipe_yaml_str = session.get_serialized_recipe()
with open(recipe_path, "w") as fp:
fp.write(recipe_yaml_str)

if (recipe_yaml_str := session.get_serialized_recipe()) is not None:
with open(recipe_path, "w") as fp:
fp.write(recipe_yaml_str)

# copy python files from cache dir to save_path if any
copy_python_files_from_model_cache(model, save_directory)
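The hunk above, from the second changed file (its path is not shown in this view, but it is the save helper that writes recipe.yaml), now fetches and checks the serialized recipe in a single step using an assignment expression (the walrus operator, Python 3.8+): if get_serialized_recipe() returns None, the recipe file is simply not written and saving continues. A self-contained sketch of the same guard-and-write pattern, with fetch_recipe as a hypothetical stand-in for session.get_serialized_recipe():

    import os
    from typing import Optional


    def fetch_recipe() -> Optional[str]:
        # Hypothetical stand-in for session.get_serialized_recipe();
        # returns None here to mimic a session that was already reset.
        return None


    def save_recipe_if_present(save_directory: str) -> None:
        recipe_path = os.path.join(save_directory, "recipe.yaml")
        # The assignment expression binds the result and tests it in one condition.
        if (recipe_yaml_str := fetch_recipe()) is not None:
            with open(recipe_path, "w") as fp:
                fp.write(recipe_yaml_str)
        # When fetch_recipe() returns None, no file is created and saving continues.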
36 changes: 25 additions & 11 deletions tests/e2e/vLLM/test_vllm.py
@@ -1,12 +1,15 @@
import os
import shutil
import unittest
from typing import Callable

import pytest
from datasets import load_dataset
from loguru import logger
from parameterized import parameterized, parameterized_class
from transformers import AutoTokenizer

from llmcompressor.core import active_session
from llmcompressor.modifiers.quantization import QuantizationModifier
from llmcompressor.transformers import SparseAutoModelForCausalLM, oneshot
from tests.testing_utils import (
@@ -22,6 +25,7 @@
vllm_installed = True
except ImportError:
vllm_installed = False
logger.warning("vllm is not installed. This test will be skipped")

# Defines the file paths to the directories containing the test configs
# for each of the quantization schemes
@@ -32,6 +36,8 @@
WNA16_2of4 = "tests/e2e/vLLM/configs/WNA16_2of4"
CONFIGS = [WNA16, FP8, INT8, ACTORDER, WNA16_2of4]

HF_MODEL_HUB_NAME = "nm-testing"


def gen_test_name(testcase_func: Callable, param_num: int, param: dict) -> str:
return "_".join(
@@ -76,8 +82,8 @@ class TestvLLM(unittest.TestCase):
save_dir = None

def setUp(self):
print("========== RUNNING ==============")
print(self.scheme)
logger.info("========== RUNNING ==============")
logger.debug(self.scheme)

self.device = "cuda:0"
self.oneshot_kwargs = {}
@@ -88,6 +94,7 @@ def setUp(self):
"The president of the US is",
"My name is",
]
self.session = active_session()

def test_vllm(self):
import torch
@@ -124,33 +131,40 @@ def test_vllm(self):
)

# Apply quantization.
print("ONESHOT KWARGS", self.oneshot_kwargs)
logger.debug("ONESHOT KWARGS", self.oneshot_kwargs)
oneshot(
**self.oneshot_kwargs,
clear_sparse_session=True,
oneshot_device=self.device,
)

self.oneshot_kwargs["model"].save_pretrained(self.save_dir)
tokenizer.save_pretrained(self.save_dir)

# Reset after session info is extracted on save -- recipe
self.session.reset()

# Run vLLM with saved model
print("================= RUNNING vLLM =========================")
logger.info("================= RUNNING vLLM =========================")
sampling_params = SamplingParams(temperature=0.80, top_p=0.95)
if "W4A16_2of4" in self.scheme:
# required by the kernel
llm = LLM(model=self.save_dir, dtype=torch.float16)
else:
llm = LLM(model=self.save_dir)
outputs = llm.generate(self.prompts, sampling_params)
print("================= vLLM GENERATION ======================")

logger.info("================= vLLM GENERATION ======================")
for output in outputs:
assert output
prompt = output.prompt
generated_text = output.outputs[0].text
print("PROMPT", prompt)
print("GENERATED TEXT", generated_text)
print("================= UPLOADING TO HUB ======================")
self.oneshot_kwargs["model"].push_to_hub(f"nm-testing/{self.save_dir}-e2e")
tokenizer.push_to_hub(f"nm-testing/{self.save_dir}-e2e")
logger.debug("PROMPT", prompt)
logger.debug("GENERATED TEXT", generated_text)

logger.info("================= UPLOADING TO HUB ======================")
hf_upload_path = os.path.join(HF_MODEL_HUB_NAME, f"{self.save_dir}-e2e")
self.oneshot_kwargs["model"].push_to_hub(hf_upload_path)
tokenizer.push_to_hub(hf_upload_path)

def tearDown(self):
if self.save_dir is not None:
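The test changes above pin down an ordering: the recipe is serialized while save_pretrained runs, and only afterwards is the session reset so finalize cannot run a second time against the same session; the Hugging Face repository id is then built from the HF_MODEL_HUB_NAME constant rather than an inline f-string. A condensed sketch of that flow under the assumptions visible in the diff (model and tokenizer are any objects exposing the usual save_pretrained and push_to_hub methods):

    import os

    from llmcompressor.core import active_session

    HF_MODEL_HUB_NAME = "nm-testing"  # mirrors the constant added in the test module


    def save_reset_and_upload(model, tokenizer, save_dir: str) -> None:
        # Saving serializes the recipe through the active session.
        model.save_pretrained(save_dir)
        tokenizer.save_pretrained(save_dir)

        # Reset only after the recipe has been extracted during save; this keeps
        # finalize from running more than once on the same session.
        active_session().reset()

        # Build the upload path from the shared constant.
        hf_upload_path = os.path.join(HF_MODEL_HUB_NAME, f"{save_dir}-e2e")
        model.push_to_hub(hf_upload_path)
        tokenizer.push_to_hub(hf_upload_path)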

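One detail of the loguru API used in the new logging calls: messages are formatted lazily with str.format, so positional arguments appear in the output only when the message contains {} placeholders; a print-style call like logger.debug("ONESHOT KWARGS", self.oneshot_kwargs) logs just the literal text. A small sketch of the placeholder form, in case the argument values are meant to show up (the values here are illustrative only):

    from loguru import logger

    oneshot_kwargs = {"num_calibration_samples": 256}  # illustrative value
    prompt = "My name is"

    logger.debug("ONESHOT KWARGS {}", oneshot_kwargs)  # {} is filled via str.format
    logger.debug("PROMPT {}", prompt)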