Skip to content

Commit

Permalink
Fix vLLM unload logic when model is None (#1080)
Browse files Browse the repository at this point in the history
  • Loading branch information
gabrielmbmb authored Dec 18, 2024
1 parent cf28976 commit c2ae3f1
Show file tree
Hide file tree
Showing 5 changed files with 34 additions and 72 deletions.
70 changes: 0 additions & 70 deletions src/distilabel/models/llms/_dummy.py

This file was deleted.

3 changes: 3 additions & 0 deletions src/distilabel/models/llms/vllm.py
Original file line number Diff line number Diff line change
Expand Up @@ -224,6 +224,9 @@ def unload(self) -> None:
super().unload()

def _cleanup_vllm_model(self) -> None:
if self._model is None:
return

import torch # noqa
from vllm.distributed.parallel_state import (
destroy_distributed_environment,
Expand Down
26 changes: 25 additions & 1 deletion tests/integration/test_generator_and_sampler.py
Original file line number Diff line number Diff line change
Expand Up @@ -12,12 +12,36 @@
# See the License for the specific language governing permissions and
# limitations under the License.

from distilabel.models.llms._dummy import DummyAsyncLLM
from typing import TYPE_CHECKING, Any

from distilabel.models.llms.base import AsyncLLM
from distilabel.pipeline import Pipeline
from distilabel.steps import CombineOutputs, LoadDataFromDicts
from distilabel.steps.generators.data_sampler import DataSampler
from distilabel.steps.tasks import TextGeneration

if TYPE_CHECKING:
from distilabel.typing import FormattedInput, GenerateOutput


class DummyAsyncLLM(AsyncLLM):
structured_output: Any = None

def load(self) -> None:
pass

@property
def model_name(self) -> str:
return "test"

async def agenerate( # type: ignore
self, input: "FormattedInput", num_generations: int = 1
) -> "GenerateOutput":
return {
"generations": ["output" for _ in range(num_generations)],
"statistics": {},
}


def get_pipeline():
with Pipeline() as pipe:
Expand Down
4 changes: 3 additions & 1 deletion tests/unit/models/llms/test_vllm.py
Original file line number Diff line number Diff line change
Expand Up @@ -105,7 +105,9 @@ class Animal(BaseModel):
class TestvLLM:
@pytest.mark.parametrize(
"multi_structured_output",
(True, False),
# TODO: uncomment once with update our code to work with `outlines>0.1.0`
# (True, False),
(False,),
)
@pytest.mark.parametrize(
"num_generations, expected_result",
Expand Down
3 changes: 3 additions & 0 deletions tests/unit/steps/tasks/structured_outputs/test_outlines.py
Original file line number Diff line number Diff line change
Expand Up @@ -100,6 +100,9 @@ class DummyUserTest(BaseModel):
}


@pytest.mark.skip(
reason="won't work until we update our code to work with `outlines>0.1.0`"
)
class TestOutlinesIntegration:
@pytest.mark.parametrize(
"format, schema, prompt",
Expand Down

0 comments on commit c2ae3f1

Please sign in to comment.