Fix for e2e tests #927

Merged
merged 11 commits on Nov 22, 2024
10 changes: 8 additions & 2 deletions src/llmcompressor/core/session.py
@@ -1,6 +1,8 @@
from dataclasses import dataclass
from typing import Any, Callable, Dict, List, Optional, Union

from loguru import logger

from llmcompressor.core.events import EventType
from llmcompressor.core.helpers import log_model_info, should_log_model_info
from llmcompressor.core.lifecycle import CompressionLifecycle
@@ -260,12 +262,16 @@ def reset_stage(self):
self.lifecycle.initialized_ = False
self.lifecycle.finalized = False

def get_serialized_recipe(self) -> str:
def get_serialized_recipe(self) -> Optional[str]:
"""
:return: serialized string of the current compiled recipe
"""
recipe = self.lifecycle.recipe_container.compiled_recipe
return recipe.yaml()

if recipe is not None and hasattr(recipe, "yaml"):
return recipe.yaml()

logger.warning("Recipe not found in session - it may have been reset")

def _log_model_info(self):
# Log model level logs if cadence reached
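The change above turns `get_serialized_recipe()` into an `Optional[str]` accessor instead of assuming a compiled recipe exists. A minimal sketch of the case it guards against, assuming `reset()` discards the compiled recipe as the new warning message suggests:

```python
from llmcompressor.core import active_session

session = active_session()
session.reset()  # assumption: reset() discards the compiled recipe

# With the old code this path raised AttributeError (None.yaml());
# now it logs a warning and returns None.
recipe_yaml = session.get_serialized_recipe()
assert recipe_yaml is None
```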
@@ -191,9 +191,10 @@ def skip(*args, **kwargs):

recipe_path = os.path.join(save_directory, "recipe.yaml")
session = active_session()
recipe_yaml_str = session.get_serialized_recipe()
with open(recipe_path, "w") as fp:
fp.write(recipe_yaml_str)

if (recipe_yaml_str := session.get_serialized_recipe()) is not None:
with open(recipe_path, "w") as fp:
fp.write(recipe_yaml_str)

# copy python files from cache dir to save_path if any
copy_python_files_from_model_cache(model, save_directory)
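The caller now writes `recipe.yaml` only when serialization succeeds. The same guard expressed as a standalone helper — a sketch, not code from this PR; `dump_if_available` is a hypothetical name:

```python
from typing import Callable, Optional


def dump_if_available(fetch: Callable[[], Optional[str]], path: str) -> bool:
    """Write fetch()'s result to path, skipping the write when it returns None."""
    if (text := fetch()) is not None:
        with open(path, "w") as fp:
            fp.write(text)
        return True
    return False


# Usage mirroring the hunk above:
# dump_if_available(active_session().get_serialized_recipe, os.path.join(save_directory, "recipe.yaml"))
```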
36 changes: 25 additions & 11 deletions tests/e2e/vLLM/test_vllm.py
@@ -1,12 +1,15 @@
import os
import shutil
import unittest
from typing import Callable

import pytest
from datasets import load_dataset
from loguru import logger
from parameterized import parameterized, parameterized_class
from transformers import AutoTokenizer

from llmcompressor.core import active_session
from llmcompressor.modifiers.quantization import QuantizationModifier
from llmcompressor.transformers import SparseAutoModelForCausalLM, oneshot
from tests.testing_utils import (
@@ -22,6 +25,7 @@
vllm_installed = True
except ImportError:
vllm_installed = False
logger.warning("vllm is not installed. This test will be skipped")
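
The import guard above only records a flag; the usual companion (not shown in this excerpt) is a `pytest.mark.skipif` that consumes it. A sketch under that assumption:

```python
import pytest
from loguru import logger

try:
    from vllm import LLM, SamplingParams  # noqa: F401

    vllm_installed = True
except ImportError:
    vllm_installed = False
    logger.warning("vllm is not installed. This test will be skipped")


# Illustrative: skip any test that needs the optional dependency.
@pytest.mark.skipif(not vllm_installed, reason="vLLM is not installed, skipping e2e test")
def test_needs_vllm():
    assert vllm_installed
```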

# Defines the file paths to the directories containing the test configs
# for each of the quantization schemes
@@ -32,6 +36,8 @@
WNA16_2of4 = "tests/e2e/vLLM/configs/WNA16_2of4"
CONFIGS = [WNA16, FP8, INT8, ACTORDER, WNA16_2of4]

HF_MODEL_HUB_NAME = "nm-testing"


def gen_test_name(testcase_func: Callable, param_num: int, param: dict) -> str:
return "_".join(
@@ -76,8 +82,8 @@ class TestvLLM(unittest.TestCase):
save_dir = None

def setUp(self):
print("========== RUNNING ==============")
print(self.scheme)
logger.info("========== RUNNING ==============")
logger.debug(self.scheme)

self.device = "cuda:0"
self.oneshot_kwargs = {}
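
The `print` calls in the test are being swapped for loguru. One usage note, hedged: loguru renders extra positional arguments through `{}` placeholders in the message, so an argument passed without a placeholder is simply dropped by the underlying `str.format()` call. A small sketch (the kwargs values are illustrative):

```python
from loguru import logger

oneshot_kwargs = {"num_calibration_samples": 256, "max_seq_length": 2048}

# With a placeholder, the argument is rendered into the log record
logger.debug("ONESHOT KWARGS {}", oneshot_kwargs)

# Without a placeholder, str.format() ignores the extra argument
logger.debug("ONESHOT KWARGS", oneshot_kwargs)

# Plain status banners need no arguments at all
logger.info("========== RUNNING ==============")
```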
@@ -88,6 +94,7 @@ def setUp(self):
"The president of the US is",
"My name is",
]
self.session = active_session()

def test_vllm(self):
import torch
@@ -124,33 +131,40 @@
)

# Apply quantization.
print("ONESHOT KWARGS", self.oneshot_kwargs)
logger.debug("ONESHOT KWARGS", self.oneshot_kwargs)
oneshot(
**self.oneshot_kwargs,
clear_sparse_session=True,
oneshot_device=self.device,
)

self.oneshot_kwargs["model"].save_pretrained(self.save_dir)
tokenizer.save_pretrained(self.save_dir)

# Reset after session info is extracted on save -- recipe
self.session.reset()

# Run vLLM with saved model
print("================= RUNNING vLLM =========================")
logger.info("================= RUNNING vLLM =========================")
sampling_params = SamplingParams(temperature=0.80, top_p=0.95)
if "W4A16_2of4" in self.scheme:
# required by the kernel
llm = LLM(model=self.save_dir, dtype=torch.float16)
else:
llm = LLM(model=self.save_dir)
outputs = llm.generate(self.prompts, sampling_params)
print("================= vLLM GENERATION ======================")

logger.info("================= vLLM GENERATION ======================")
for output in outputs:
assert output
prompt = output.prompt
generated_text = output.outputs[0].text
print("PROMPT", prompt)
print("GENERATED TEXT", generated_text)
print("================= UPLOADING TO HUB ======================")
self.oneshot_kwargs["model"].push_to_hub(f"nm-testing/{self.save_dir}-e2e")
tokenizer.push_to_hub(f"nm-testing/{self.save_dir}-e2e")
logger.debug("PROMPT", prompt)
logger.debug("GENERATED TEXT", generated_text)

logger.info("================= UPLOADING TO HUB ======================")
hf_upload_path = os.path.join(HF_MODEL_HUB_NAME, f"{self.save_dir}-e2e")
self.oneshot_kwargs["model"].push_to_hub(hf_upload_path)
tokenizer.push_to_hub(hf_upload_path)

def tearDown(self):
if self.save_dir is not None:
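A condensed sketch of the flow the updated test exercises — save, reset the session, reload the checkpoint in vLLM, and assemble the hub upload path. The directory name and prompt are placeholders, not values from the test configs:

```python
import os

from vllm import LLM, SamplingParams

from llmcompressor.core import active_session

HF_MODEL_HUB_NAME = "nm-testing"
save_dir = "TinyLlama-1.1B-Chat-v1.0-W4A16"  # placeholder output directory

# 1. save_pretrained() serializes the recipe, so the session is reset afterwards
active_session().reset()

# 2. reload the compressed checkpoint in vLLM and generate
llm = LLM(model=save_dir)
outputs = llm.generate(["My name is"], SamplingParams(temperature=0.80, top_p=0.95))

# 3. the upload target is the test org plus a "-e2e" suffix
hf_upload_path = os.path.join(HF_MODEL_HUB_NAME, f"{save_dir}-e2e")
# model.push_to_hub(hf_upload_path); tokenizer.push_to_hub(hf_upload_path)
```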