Commit ac492ee: Merge branch 'main' into docs/update-docs
plaguss committed Dec 19, 2023
2 parents: 0b33704 + 9a318e4
Showing 15 changed files with 447 additions and 150 deletions.
4 changes: 4 additions & 0 deletions examples/pipeline-llamacpp-and-openai-process.py
@@ -12,6 +12,10 @@
# See the License for the specific language governing permissions and
# limitations under the License.

+ # WARNING: to run this example in Mac OS use:
+ # no_proxy=* OBJC_DISABLE_INITIALIZE_FORK_SAFETY=YES python examples/pipeline-llamacpp-and-openai-process.py
+ # Otherwise you will get an error when loading the llama.cpp model
+
import os
from typing import TYPE_CHECKING

2 changes: 1 addition & 1 deletion examples/pipeline-pool-llm.py
@@ -39,7 +39,7 @@ def load_openai(task):
return OpenAILLM(
model="gpt-3.5-turbo",
task=task,
openai_api_key="<OPENAI_API_KEY>",
openai_api_key=os.getenv("OPENAI_API_KEY"),
max_new_tokens=512,
)

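With this change the example reads the OpenAI key from the environment instead of a hard-coded placeholder. A small, hypothetical guard you could add to your own copy of the script to fail fast when the variable is missing (an assumption about your setup, not part of the committed example):

import os

if os.getenv("OPENAI_API_KEY") is None:
    # Fail early with a clear message instead of an authentication error deep inside the pipeline.
    raise RuntimeError("Set the OPENAI_API_KEY environment variable before running this example.")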
63 changes: 63 additions & 0 deletions examples/pipeline-prometheus.py
@@ -0,0 +1,63 @@
# Copyright 2023-present, Argilla, Inc.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

import os

import torch
from datasets import Dataset
from distilabel.llm import TransformersLLM
from distilabel.pipeline import Pipeline
from distilabel.tasks.critique.prometheus import PrometheusTask
from transformers import AutoTokenizer, LlamaForCausalLM

if __name__ == "__main__":
model = LlamaForCausalLM.from_pretrained(
"kaist-ai/Prometheus-7b-v1.0", torch_dtype=torch.float16, device_map="auto"
)
tokenizer = AutoTokenizer.from_pretrained(
"meta-llama/Llama-2-7b-chat-hf", token=os.getenv("HF_TOKEN")
)
pipeline = Pipeline(
labeller=TransformersLLM(
model=model, # type: ignore
tokenizer=tokenizer,
task=PrometheusTask(
scoring_criteria="Is the provided completion accurate based on the given instruction?",
score_descriptions={
0: "Totaly off-topic and inaccurate",
1: "Incorrect and inaccurate",
2: "Almost correct, but partially inaccurate",
3: "Correct but badly phrased",
4: "Correct and accurate",
},
),
temperature=1.0,
top_p=1.0,
max_new_tokens=512,
),
)

dataset = Dataset.from_dict(
{
"instruction": ["What's the capital of Spain?"],
"completion": ["Paris"],
"ref_completion": ["Madrid"],
}
)

dataset = pipeline.generate(
dataset, # type: ignore
display_progress_bar=True,
skip_dry_run=True,
)
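To see what the pipeline produced, a couple of lines appended at the end of the script above will print the labelled rows; the exact columns added by `PrometheusTask` are not documented in this diff, so treat their names as implementation-defined:

    # Appended after pipeline.generate(...): inspect whatever columns the task added.
    print(dataset.column_names)
    print(dataset[0])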
1 change: 1 addition & 0 deletions pyproject.toml
@@ -26,6 +26,7 @@ dependencies = [
"rich >= 13.5.0",
"tenacity >= 8",
"importlib-resources >= 6.1.1; python_version < '3.9'",
"multiprocess",
]
dynamic = ["version"]

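The new `multiprocess` dependency replaces the standard library's `multiprocessing` in `src/distilabel/llm/base.py` below. Presumably the motivation is that `multiprocess` serializes with `dill`, which can ship objects the stdlib pickler rejects (lambdas, closures, locally defined functions); a small standalone sketch of that difference, independent of distilabel:

import multiprocess as mp  # same API as multiprocessing, but serializes with dill

if __name__ == "__main__":
    square = lambda x: x * x  # lambdas cannot be pickled by the stdlib multiprocessing
    with mp.Pool(2) as pool:
        print(pool.map(square, [1, 2, 3, 4]))  # prints [1, 4, 9, 16]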
63 changes: 18 additions & 45 deletions src/distilabel/llm/base.py
@@ -14,14 +14,12 @@

from __future__ import annotations

- import multiprocessing as mp
import queue
import random
import warnings
from abc import ABC, abstractmethod
from concurrent.futures import Future, ThreadPoolExecutor
from ctypes import c_char
- from enum import Enum, auto
from functools import cached_property
from threading import Thread
from typing import (
@@ -35,9 +33,11 @@
Union,
)

+ import multiprocess as mp
+
from distilabel.logger import get_logger
from distilabel.tasks.prompt import Prompt
- from distilabel.utils.types import is_list_of_futures
+ from distilabel.utils.futures import when_all_complete

if TYPE_CHECKING:
from distilabel.llm.utils import LLMOutput
@@ -47,22 +47,6 @@
logger = get_logger()


- class LLMFutures(Enum):
-     """An enum used to indicate whether the `LLM` returns futures or not, and if so
-     what the futures contains.
-     Attributes:
-         CONTAINS_ROWS: used to indicate that the `LLM` will return `Future`s that
-             contains a `List[List[LLMOutput]]`. The first list just contains 1 element
-             (the row) and the second list contains the generations for that row.
-         CONTAINS_BATCHES: used to indicate that the `LLM` will return `Future`s that
-             contains a `List[List[LLMOutput]]`.
-     """
-
-     CONTAINS_ROWS = auto()
-     CONTAINS_BATCHES = auto()


class LLM(ABC):
def __init__(
self,
@@ -198,7 +182,7 @@ def generate(
inputs: List[Dict[str, Any]],
num_generations: int = 1,
progress_callback_func: Union[Callable, None] = None,
- ) -> Union[List[Future[List["LLMOutput"]]], List[List["LLMOutput"]]]:
+ ) -> Union[List[List["LLMOutput"]], Future[List[List["LLMOutput"]]]]:
"""Generates the outputs for the given inputs using the LLM.
Args:
@@ -214,32 +198,27 @@

def _progress():
if progress_callback_func is not None:
- advance = (
-     num_generations * len(inputs)
-     if self.return_futures is not None
-     else num_generations
- )
- progress_callback_func(advance=advance)
+ progress_callback_func(advance=num_generations * len(inputs))

if self.thread_pool_executor is not None:
futures = []
for input in inputs:
future = self.thread_pool_executor.submit(
self._generate, [input], num_generations
)
- future.add_done_callback(lambda _: _progress())
futures.append(future)
- return futures
+ future = when_all_complete(futures)
+ future.add_done_callback(lambda _: _progress())
+ return future

generations = self._generate(inputs, num_generations)
_progress()
return generations

@property
- def return_futures(self) -> Union[LLMFutures, None]:
-     """Returns whether the LLM returns futures or not, and if so what the futures
-     contains."""
-     return LLMFutures.CONTAINS_ROWS
+ def return_futures(self) -> bool:
+     """Whether the `LLM` returns futures"""
+     return True


MAX_MODEL_NAME_LENGTH = 256
@@ -339,13 +318,7 @@ def run(self) -> None:
inputs=request.inputs, num_generations=request.num_generations
)

- # Generations are a list of `Future`s because the `LLM` is using a thread pool
- if is_list_of_futures(results):
-     generations = []
-     for future in results:
-         generations.extend(future.result())
- else:
-     generations = results
+ generations = results.result() if isinstance(results, Future) else results

self._result_queue.put(_TextGenerationResult(generations))

@@ -591,10 +564,9 @@ def model_name(self) -> str:
return "".join([c.decode() for c in self._model_name if c != b"\0"])

@property
- def return_futures(self) -> Union[LLMFutures, None]:
-     """Returns whether the LLM returns futures or not, and if so what the futures
-     contains."""
-     return LLMFutures.CONTAINS_BATCHES
+ def return_futures(self) -> bool:
+     """Whether the `LLM` returns futures"""
+     return True


class LLMPool:
@@ -719,5 +691,6 @@ def task(self) -> "Task":
return self.llms[0].task

@property
- def return_futures(self) -> Union[LLMFutures, None]:
-     return None
+ def return_futures(self) -> bool:
+     """Whether the `LLM` returns futures"""
+     return False
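The implementation of `when_all_complete` in `distilabel.utils.futures` is not included in this diff. A rough sketch of what such a helper could look like, under the assumption that it merges a list of per-row futures into a single future that resolves to the flattened generations once every row has finished (not the actual distilabel code):

from concurrent.futures import Future, wait
from threading import Thread
from typing import Any, List


def when_all_complete(futures: List[Future]) -> Future:
    """Return one future that resolves when every future in `futures` is done."""
    combined: Future = Future()

    def _collect() -> None:
        wait(futures)  # block until all per-row futures have finished
        results: List[Any] = []
        for future in futures:
            results.extend(future.result())  # each future holds the generations for one row
        combined.set_result(results)  # error propagation omitted in this sketch

    Thread(target=_collect, daemon=True).start()
    return combined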
(The remaining 10 changed files are not shown.)
