Skip to content

Commit

Permalink
Move Enum to Dict[str, str] to avoid serialization errors during cach…
Browse files Browse the repository at this point in the history
…ing (#482)

* Use multiprocessing instead of multiprocess for mental health

* Update Enum fields to Dict[str, str] to avoid serialization errors

* Remove commented code
  • Loading branch information
plaguss authored Mar 26, 2024
1 parent da3730b commit 448ffdb
Show file tree
Hide file tree
Showing 11 changed files with 88 additions and 134 deletions.
3 changes: 1 addition & 2 deletions src/distilabel/pipeline/local.py
Original file line number Diff line number Diff line change
Expand Up @@ -12,13 +12,12 @@
# See the License for the specific language governing permissions and
# limitations under the License.

import multiprocessing as mp
import signal
import threading
import time
from typing import TYPE_CHECKING, Any, Dict, Optional, cast

import multiprocess as mp

from distilabel.llm.mixins import CudaDevicePlacementMixin
from distilabel.pipeline.base import BasePipeline, _Batch, _BatchManager, _WriteBuffer
from distilabel.steps.base import Step
Expand Down
23 changes: 6 additions & 17 deletions src/distilabel/steps/task/evol_instruct/base.py
Original file line number Diff line number Diff line change
Expand Up @@ -12,15 +12,6 @@
# See the License for the specific language governing permissions and
# limitations under the License.

import sys

from distilabel.utils.lists import flatten_responses

if sys.version_info < (3, 11):
from enum import EnumMeta as EnumType
else:
from enum import EnumType

from typing import TYPE_CHECKING, Any, Dict, List, Optional, Union

import numpy as np
Expand All @@ -30,8 +21,9 @@
from distilabel.mixins.runtime_parameters import RuntimeParameter
from distilabel.steps.base import StepInput
from distilabel.steps.task.base import Task
from distilabel.steps.task.evol_instruct.utils import MutationTemplates
from distilabel.steps.task.evol_instruct.utils import MUTATION_TEMPLATES
from distilabel.steps.task.typing import ChatType
from distilabel.utils.lists import flatten_responses

if TYPE_CHECKING:
from distilabel.steps.typing import StepOutput
Expand Down Expand Up @@ -62,7 +54,7 @@ class EvolInstruct(Task):
num_evolutions: int
store_evolutions: bool = False
generate_answers: bool = False
mutation_templates: EnumType = Field(default=MutationTemplates)
mutation_templates: Dict[str, str] = MUTATION_TEMPLATES

seed: RuntimeParameter[int] = Field(
default=42,
Expand Down Expand Up @@ -130,11 +122,8 @@ def format_output(

@property
def mutation_templates_names(self) -> List[str]:
"""Returns the names i.e. keys of the provided `mutation_templates` enum."""
return [
member.name # type: ignore
for member in self.mutation_templates.__members__.values() # type: ignore
]
"""Returns the names i.e. keys of the provided `mutation_templates`."""
return list(self.mutation_templates.keys())

def _apply_random_mutation(self, instruction: str) -> str:
"""Applies a random mutation from the ones provided as part of the `mutation_templates`
Expand All @@ -147,7 +136,7 @@ def _apply_random_mutation(self, instruction: str) -> str:
A random mutation prompt with the provided instruction.
"""
mutation = np.random.choice(self.mutation_templates_names)
return self.mutation_templates[mutation].value.replace("<PROMPT>", instruction) # type: ignore
return self.mutation_templates[mutation].replace("<PROMPT>", instruction) # type: ignore

def _evolve_instructions(self, inputs: "StepInput") -> List[List[str]]:
"""Evolves the instructions provided as part of the inputs of the task.
Expand Down
13 changes: 3 additions & 10 deletions src/distilabel/steps/task/evol_instruct/evol_complexity/base.py
Original file line number Diff line number Diff line change
Expand Up @@ -12,17 +12,10 @@
# See the License for the specific language governing permissions and
# limitations under the License.

import sys

from pydantic import Field
from typing import Dict

from distilabel.steps.task.evol_instruct.base import EvolInstruct
from distilabel.steps.task.evol_instruct.evol_complexity.utils import MutationTemplates

if sys.version_info < (3, 11):
from enum import EnumMeta as EnumType
else:
from enum import EnumType
from distilabel.steps.task.evol_instruct.evol_complexity.utils import MUTATION_TEMPLATES


class EvolComplexity(EvolInstruct):
Expand Down Expand Up @@ -50,4 +43,4 @@ class EvolComplexity(EvolInstruct):
- `store_evolutions=True`, `generate_answers=True` -> (evolved_instructions, model_name, answer)
"""

mutation_templates: EnumType = Field(default=MutationTemplates)
mutation_templates: Dict[str, str] = MUTATION_TEMPLATES
Original file line number Diff line number Diff line change
Expand Up @@ -12,20 +12,13 @@
# See the License for the specific language governing permissions and
# limitations under the License.

import sys

from pydantic import Field
from typing import Dict

from distilabel.steps.task.evol_instruct.evol_complexity.utils import (
GenerationMutationTemplates,
GENERATION_MUTATION_TEMPLATES,
)
from distilabel.steps.task.evol_instruct.generator import EvolInstructGenerator

if sys.version_info < (3, 11):
from enum import EnumMeta as EnumType
else:
from enum import EnumType


class EvolComplexityGenerator(EvolInstructGenerator):
"""
Expand All @@ -52,4 +45,4 @@ class EvolComplexityGenerator(EvolInstructGenerator):
- `generate_answers=True` -> (instruction, model_name, answer)
"""

mutation_templates: EnumType = Field(default=GenerationMutationTemplates)
mutation_templates: Dict[str, str] = GENERATION_MUTATION_TEMPLATES
45 changes: 19 additions & 26 deletions src/distilabel/steps/task/evol_instruct/evol_complexity/utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -12,35 +12,28 @@
# See the License for the specific language governing permissions and
# limitations under the License.

import sys

from distilabel.steps.task.evol_instruct.utils import (
GenerationMutationTemplates as GenerationMutationTemplatesEvolInstruct,
GENERATION_MUTATION_TEMPLATES as GENERATION_MUTATION_TEMPLATES_EVOL_INSTRUCT,
)
from distilabel.steps.task.evol_instruct.utils import (
MutationTemplates as MutationTemplatesEvolInstruct,
MUTATION_TEMPLATES as MUTATION_TEMPLATES_EVOL_INSTRUCT,
)

if sys.version_info < (3, 11):
from enum import Enum as StrEnum
else:
from enum import StrEnum


class MutationTemplates(StrEnum):
CONSTRAINTS = MutationTemplatesEvolInstruct.CONSTRAINTS.value
DEEPENING = MutationTemplatesEvolInstruct.DEEPENING.value
CONCRETIZING = MutationTemplatesEvolInstruct.CONCRETIZING.value
INCREASED_REASONING_STEPS = (
MutationTemplatesEvolInstruct.INCREASED_REASONING_STEPS.value
)

MUTATION_TEMPLATES = {
"CONSTRAINTS": MUTATION_TEMPLATES_EVOL_INSTRUCT["CONSTRAINTS"],
"DEEPENING": MUTATION_TEMPLATES_EVOL_INSTRUCT["DEEPENING"],
"CONCRETIZING": MUTATION_TEMPLATES_EVOL_INSTRUCT["CONCRETIZING"],
"INCREASED_REASONING_STEPS": MUTATION_TEMPLATES_EVOL_INSTRUCT[
"INCREASED_REASONING_STEPS"
],
}

class GenerationMutationTemplates(StrEnum):
FRESH_START = GenerationMutationTemplatesEvolInstruct.FRESH_START.value
CONSTRAINTS = GenerationMutationTemplatesEvolInstruct.CONSTRAINTS.value
DEEPENING = GenerationMutationTemplatesEvolInstruct.DEEPENING.value
CONCRETIZING = GenerationMutationTemplatesEvolInstruct.CONCRETIZING.value
INCREASED_REASONING_STEPS = (
GenerationMutationTemplatesEvolInstruct.INCREASED_REASONING_STEPS.value
)
GENERATION_MUTATION_TEMPLATES = {
"FRESH_START": GENERATION_MUTATION_TEMPLATES_EVOL_INSTRUCT["FRESH_START"],
"CONSTRAINTS": GENERATION_MUTATION_TEMPLATES_EVOL_INSTRUCT["CONSTRAINTS"],
"DEEPENING": GENERATION_MUTATION_TEMPLATES_EVOL_INSTRUCT["DEEPENING"],
"CONCRETIZING": GENERATION_MUTATION_TEMPLATES_EVOL_INSTRUCT["CONCRETIZING"],
"INCREASED_REASONING_STEPS": GENERATION_MUTATION_TEMPLATES_EVOL_INSTRUCT[
"INCREASED_REASONING_STEPS"
],
}
25 changes: 8 additions & 17 deletions src/distilabel/steps/task/evol_instruct/generator.py
Original file line number Diff line number Diff line change
Expand Up @@ -14,28 +14,22 @@

import sys

from distilabel.mixins.runtime_parameters import RuntimeParameter
from distilabel.utils.lists import flatten_responses

if sys.version_info < (3, 9):
import importlib_resources
else:
import importlib.resources as importlib_resources

if sys.version_info < (3, 11):
from enum import EnumMeta as EnumType
else:
from enum import EnumType

from functools import cached_property
from typing import TYPE_CHECKING, Any, Dict, List, Optional

import numpy as np
from pydantic import Field, PrivateAttr
from typing_extensions import override

from distilabel.mixins.runtime_parameters import RuntimeParameter
from distilabel.steps.task.base import GeneratorTask
from distilabel.steps.task.evol_instruct.utils import GenerationMutationTemplates
from distilabel.steps.task.evol_instruct.utils import GENERATION_MUTATION_TEMPLATES
from distilabel.utils.lists import flatten_responses

if TYPE_CHECKING:
from distilabel.steps.task.typing import ChatType
Expand Down Expand Up @@ -66,7 +60,7 @@ class EvolInstructGenerator(GeneratorTask):

num_instructions: int
generate_answers: bool = False
mutation_templates: EnumType = Field(default=GenerationMutationTemplates)
mutation_templates: Dict[str, str] = GENERATION_MUTATION_TEMPLATES

min_length: RuntimeParameter[int] = Field(
default=512,
Expand Down Expand Up @@ -98,7 +92,7 @@ def _generate_seed_texts(self) -> List[str]:
for _ in range(self.num_instructions * 10):
num_words = np.random.choice([1, 2, 3, 4])
seed_texts.append(
self.mutation_templates.FRESH_START.value.replace( # type: ignore
self.mutation_templates["FRESH_START"].replace( # type: ignore
"<PROMPT>",
", ".join(
[
Expand Down Expand Up @@ -171,11 +165,8 @@ def format_output( # type: ignore

@property
def mutation_templates_names(self) -> List[str]:
"""Returns the names i.e. keys of the provided `mutation_templates` enum."""
return [
member.name # type: ignore
for member in self.mutation_templates.__members__.values() # type: ignore
]
"""Returns the names i.e. keys of the provided `mutation_templates`."""
return list(self.mutation_templates.keys())

def _apply_random_mutation(self, iter_no: int) -> List["ChatType"]:
"""Applies a random mutation from the ones provided as part of the `mutation_templates`
Expand All @@ -201,7 +192,7 @@ def _apply_random_mutation(self, iter_no: int) -> List["ChatType"]:
self._prompts[idx] = np.random.choice(self._seed_texts) # type: ignore

prompt_with_template = (
self.mutation_templates[mutation].value.replace( # type: ignore
self.mutation_templates[mutation].replace( # type: ignore
"<PROMPT>",
self._prompts[idx], # type: ignore
) # type: ignore
Expand Down
51 changes: 23 additions & 28 deletions src/distilabel/steps/task/evol_instruct/utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -12,13 +12,6 @@
# See the License for the specific language governing permissions and
# limitations under the License.

import sys

if sys.version_info < (3, 11):
from enum import Enum as StrEnum
else:
from enum import StrEnum


REWRITE_INSTRUCTION = """
I want you act as a Prompt Rewriter.\n
Expand All @@ -42,34 +35,36 @@
""".lstrip()


class MutationTemplates(StrEnum):
CONSTRAINTS = REWRITE_INSTRUCTION.format(
MUTATION_TEMPLATES = {
"CONSTRAINTS": REWRITE_INSTRUCTION.format(
"Please add one more constraints/requirements into '#The Given Prompt#'"
)
DEEPENING = REWRITE_INSTRUCTION.format(
),
"DEEPENING": REWRITE_INSTRUCTION.format(
"If #The Given Prompt# contains inquiries about certain issues, the depth and breadth of the inquiry can be increased."
)
CONCRETIZING = REWRITE_INSTRUCTION.format(
),
"CONCRETIZING": REWRITE_INSTRUCTION.format(
"Please replace general concepts with more specific concepts."
)
INCREASED_REASONING_STEPS = REWRITE_INSTRUCTION.format(
),
"INCREASED_REASONING_STEPS": REWRITE_INSTRUCTION.format(
"If #The Given Prompt# can be solved with just a few simple thinking processes, you can rewrite it to explicitly request multiple-step reasoning."
)
BREADTH = CREATE_INSTRUCTION
),
"BREADTH": CREATE_INSTRUCTION,
}


class GenerationMutationTemplates(StrEnum):
FRESH_START = "Write one question or request containing one or more of the following words: <PROMPT>"
CONSTRAINTS = REWRITE_INSTRUCTION.format(
GENERATION_MUTATION_TEMPLATES = {
"FRESH_START": "Write one question or request containing one or more of the following words: <PROMPT>",
"CONSTRAINTS": REWRITE_INSTRUCTION.format(
"Please add one more constraints/requirements into '#The Given Prompt#'"
)
DEEPENING = REWRITE_INSTRUCTION.format(
),
"DEEPENING": REWRITE_INSTRUCTION.format(
"If #The Given Prompt# contains inquiries about certain issues, the depth and breadth of the inquiry can be increased."
)
CONCRETIZING = REWRITE_INSTRUCTION.format(
),
"CONCRETIZING": REWRITE_INSTRUCTION.format(
"Please replace general concepts with more specific concepts."
)
INCREASED_REASONING_STEPS = REWRITE_INSTRUCTION.format(
),
"INCREASED_REASONING_STEPS": REWRITE_INSTRUCTION.format(
"If #The Given Prompt# can be solved with just a few simple thinking processes, you can rewrite it to explicitly request multiple-step reasoning."
)
BREADTH = CREATE_INSTRUCTION
),
"BREADTH": CREATE_INSTRUCTION,
}
Original file line number Diff line number Diff line change
Expand Up @@ -18,7 +18,7 @@
EvolComplexity,
)
from distilabel.steps.task.evol_instruct.evol_complexity.utils import (
MutationTemplates,
MUTATION_TEMPLATES,
)


Expand All @@ -31,5 +31,5 @@ def test_mutation_templates(self, dummy_llm: LLM) -> None:
assert task.name == "task"
assert task.llm is dummy_llm
assert task.num_evolutions == 2
assert task.mutation_templates == MutationTemplates
assert "BREADTH" not in task.mutation_templates.__members__
assert task.mutation_templates == MUTATION_TEMPLATES
assert "BREADTH" not in task.mutation_templates
Original file line number Diff line number Diff line change
Expand Up @@ -18,7 +18,7 @@
EvolComplexityGenerator,
)
from distilabel.steps.task.evol_instruct.evol_complexity.utils import (
GenerationMutationTemplates,
GENERATION_MUTATION_TEMPLATES,
)


Expand All @@ -31,5 +31,5 @@ def test_mutation_templates(self, dummy_llm: LLM) -> None:
assert task.name == "task"
assert task.llm is dummy_llm
assert task.num_instructions == 2
assert task.mutation_templates == GenerationMutationTemplates
assert "BREADTH" not in task.mutation_templates.__members__
assert task.mutation_templates == GENERATION_MUTATION_TEMPLATES
assert "BREADTH" not in task.mutation_templates
Loading

0 comments on commit 448ffdb

Please sign in to comment.