Skip to content

Commit

Permalink
Add ShardMapper prototype.
Browse files Browse the repository at this point in the history
  • Loading branch information
rmitsch committed Oct 18, 2023
1 parent f440ca4 commit e47f762
Show file tree
Hide file tree
Showing 18 changed files with 128 additions and 131 deletions.
37 changes: 22 additions & 15 deletions spacy_llm/tasks/builtin_task.py
Original file line number Diff line number Diff line change
Expand Up @@ -10,8 +10,7 @@

from ..compat import Self
from ..registry import lowercase_normalizer
from ..ty import FewshotExample, NTokenEstimator, ShardMapper, ShardReducer
from ..ty import TaskResponseParser
from ..ty import FewshotExample, ShardMapper, ShardReducer, TaskResponseParser


class BuiltinTask(abc.ABC):
Expand All @@ -35,7 +34,6 @@ def __init__(
prompt_example_type: Type[FewshotExample[Self]],
template: str,
prompt_examples: Optional[List[FewshotExample[Self]]],
n_token_estimator: NTokenEstimator,
shard_mapper: ShardMapper,
shard_reducer: ShardReducer,
):
Expand All @@ -44,15 +42,15 @@ def __init__(
prompt_example_type (Type[FewshotExample[Self]): Type to use for fewshot examples.
template (str): Prompt template passed to the model.
prompt_examples (Optional[List[FewshotExample[Self]]]): Optional list of few-shot examples to include in prompts.
n_token_estimator (NTokenEstimator): Estimates number of tokens in a string.
shard_mapper (ShardMapper): Maps docs to shards if they don't fit into the model context.
shard_reducer (ShardReducer): Reduces doc shards back into one doc instance.
"""
self._parse_responses = parse_responses
self._prompt_examples = prompt_examples or []
self._template = template
self._prompt_example_type = prompt_example_type
self._n_token_estimator = n_token_estimator
self._shard_mapper = shard_mapper
self._shard_reducer = shard_reducer

def generate_prompts(self, docs: Iterable[Doc]) -> Iterable[Any]:
"""Generate prompts from docs.
Expand All @@ -61,17 +59,32 @@ def generate_prompts(self, docs: Iterable[Doc]) -> Iterable[Any]:
"""
environment = jinja2.Environment()
_template = environment.from_string(self._template)

def render_template(shard: Doc) -> str:
    """Renders template for a given doc (shard).
    shard (Doc): Doc shard. Note that if the prompt is small enough to fit within the model's context window,
    there will only be one shard, which is identical to the original doc.
    RETURNS (str): Rendered template.
    """
    # Render the shard's own text, not the enclosing loop's `doc`: closing over
    # `doc` would render the full document for every shard (and break if this
    # helper is invoked after the loop variable changes), defeating sharding.
    return _template.render(
        text=shard.text,
        prompt_examples=self._prompt_examples,
        **self._get_prompt_data(shard),
    )

for doc in self._preprocess_docs_for_prompt(docs):
# todo make prompt data a doc-dependent function (worry about EL after this works for available tasks)
prompt = _template.render(
text=doc.text,
prompt_examples=self._prompt_examples,
**self._prompt_data,
**self._get_prompt_data(doc),
)
yield prompt

@property
def _prompt_data(self) -> Dict[str, Any]:
"""Returns data injected into prompt template. No-op if not overridden by inheriting task class.
def _get_prompt_data(self, shard: Doc) -> Dict[str, Any]:
"""Returns data injected into prompt template. No-op if not overridden by inheriting task class. The data
returned by this might be static (i. e. the same for all doc shards) or dynamic (contingent on the doc shard).
shard (Doc): Doc (shard) for which prompt data should be fetched.
RETURNS (Dict[str, Any]): Data injected into prompt template.
"""
return {}
Expand Down Expand Up @@ -121,7 +134,6 @@ def get_cfg(self) -> Dict[str, Any]:

def set_cfg(self, cfg: Dict[str, Any]) -> None:
"""Deserialize the task's configuration attributes.
cfg (Dict[str, Any]): dictionary containing configuration attributes.
"""
for key, value in cfg.items():
Expand All @@ -134,7 +146,6 @@ def _get_prompt_examples(self) -> List[Dict[str, Any]]:

def _set_prompt_examples(self, examples: List[Dict[str, Any]]) -> None:
"""Set prompt examples.
examples (List[Dict[str, Any]]): prompt examples.
"""
self._prompt_examples = [
Expand Down Expand Up @@ -191,7 +202,6 @@ def to_disk(
path (Path): A path (currently unused).
exclude (Tuple): Names of properties to exclude from serialization.
"""

serialize = {
"cfg": lambda p: srsly.write_json(p, self.get_cfg()),
"prompt_examples": lambda p: srsly.write_msgpack(
Expand Down Expand Up @@ -252,7 +262,6 @@ def __init__(
prompt_example_type: Type[FewshotExample[Self]],
template: str,
prompt_examples: Optional[List[FewshotExample[Self]]],
n_token_estimator: NTokenEstimator,
shard_mapper: ShardMapper,
shard_reducer: ShardReducer,
labels: List[str],
Expand All @@ -265,7 +274,6 @@ def __init__(
prompt_example_type (Type[FewshotExample[Self]): Type to use for fewshot examples.
template (str): Prompt template passed to the model.
prompt_examples (Optional[List[FewshotExample[Self]]]): Optional list of few-shot examples to include in prompts.
n_token_estimator (NTokenEstimator): Estimates number of tokens in a string.
shard_mapper (ShardMapper): Maps docs to shards if they don't fit into the model context.
shard_reducer (ShardReducer): Reduces doc shards back into one doc instance.
labels (List[str]): List of labels to pass to the template.
Expand All @@ -281,7 +289,6 @@ def __init__(
prompt_example_type=prompt_example_type,
template=template,
prompt_examples=prompt_examples,
n_token_estimator=n_token_estimator,
shard_mapper=shard_mapper,
shard_reducer=shard_reducer,
)
Expand Down
8 changes: 3 additions & 5 deletions spacy_llm/tasks/lemma/registry.py
Original file line number Diff line number Diff line change
@@ -1,9 +1,9 @@
from typing import Optional, Type

from ...registry import registry
from ...ty import ExamplesConfigType, FewshotExample, NTokenEstimator, Scorer
from ...ty import ShardMapper, ShardReducer, TaskResponseParser
from ..util.sharding import make_n_token_estimator, make_shard_mapper
from ...ty import ExamplesConfigType, FewshotExample, Scorer, ShardMapper, ShardReducer
from ...ty import TaskResponseParser
from ..util.sharding import make_shard_mapper
from .parser import parse_responses_v1
from .task import DEFAULT_LEMMA_TEMPLATE_V1, LemmaTask
from .util import LemmaExample, reduce_shards_to_doc, score
Expand All @@ -30,7 +30,6 @@ def make_lemma_task(
parse_responses: Optional[TaskResponseParser[LemmaTask]] = None,
prompt_example_type: Optional[Type[FewshotExample]] = None,
examples: ExamplesConfigType = None,
n_token_estimator: Optional[NTokenEstimator] = None,
shard_mapper: Optional[ShardMapper] = None,
shard_reducer: Optional[ShardReducer] = None,
scorer: Optional[Scorer] = None,
Expand Down Expand Up @@ -58,7 +57,6 @@ def make_lemma_task(
parse_responses=parse_responses or parse_responses_v1,
prompt_example_type=example_type,
prompt_examples=lemma_examples,
n_token_estimator=n_token_estimator or make_n_token_estimator(),
shard_mapper=shard_mapper or make_shard_mapper(),
shard_reducer=shard_reducer or make_shard_reducer(),
scorer=scorer or score,
Expand Down
6 changes: 1 addition & 5 deletions spacy_llm/tasks/lemma/task.py
Original file line number Diff line number Diff line change
Expand Up @@ -5,8 +5,7 @@
from spacy.training import Example

from ...compat import Self
from ...ty import FewshotExample, NTokenEstimator, Scorer, ShardMapper, ShardReducer
from ...ty import TaskResponseParser
from ...ty import FewshotExample, Scorer, ShardMapper, ShardReducer, TaskResponseParser
from ..builtin_task import BuiltinTask
from ..templates import read_template

Expand All @@ -20,7 +19,6 @@ def __init__(
prompt_example_type: Type[FewshotExample[Self]],
prompt_examples: Optional[List[FewshotExample[Self]]],
template: str,
n_token_estimator: NTokenEstimator,
shard_mapper: ShardMapper,
shard_reducer: ShardReducer,
scorer: Scorer,
Expand All @@ -31,7 +29,6 @@ def __init__(
prompt_example_type (Type[FewshotExample[Self]): Type to use for fewshot examples.
prompt_examples (Optional[List[FewshotExample[Self]]]): Optional list of few-shot examples to include in prompts.
template (str): Prompt template passed to the model.
n_token_estimator (NTokenEstimator): Estimates number of tokens in a string.
shard_mapper (ShardMapper): Maps docs to shards if they don't fit into the model context.
shard_reducer (ShardReducer): Reduces doc shards back into one doc instance.
scorer (Scorer): Scorer function.
Expand All @@ -41,7 +38,6 @@ def __init__(
prompt_example_type=prompt_example_type,
template=template,
prompt_examples=prompt_examples,
n_token_estimator=n_token_estimator,
shard_mapper=shard_mapper,
shard_reducer=shard_reducer,
)
Expand Down
11 changes: 3 additions & 8 deletions spacy_llm/tasks/ner/registry.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,13 +2,13 @@

from ...compat import Literal
from ...registry import registry
from ...ty import ExamplesConfigType, FewshotExample, NTokenEstimator, Scorer
from ...ty import ShardMapper, ShardReducer, TaskResponseParser
from ...ty import ExamplesConfigType, FewshotExample, Scorer, ShardMapper, ShardReducer
from ...ty import TaskResponseParser
from ...util import split_labels
from ..span import parse_responses as parse_span_responses
from ..span import parse_responses_cot as parse_span_responses_cot
from ..span.util import check_label_consistency, check_label_consistency_cot
from ..util.sharding import make_n_token_estimator, make_shard_mapper
from ..util.sharding import make_shard_mapper
from .task import DEFAULT_NER_TEMPLATE_V1, DEFAULT_NER_TEMPLATE_V2
from .task import DEFAULT_NER_TEMPLATE_V3, NERTask, SpanTask
from .util import NERCoTExample, NERExample, reduce_shards_to_doc, score
Expand Down Expand Up @@ -58,7 +58,6 @@ def make_ner_task(
labels=labels_list,
template=DEFAULT_NER_TEMPLATE_V1,
prompt_examples=span_examples,
n_token_estimator=make_n_token_estimator(),
shard_mapper=make_shard_mapper(),
shard_reducer=make_shard_reducer(),
normalizer=normalizer,
Expand Down Expand Up @@ -121,7 +120,6 @@ def make_ner_task_v2(
template=template,
label_definitions=label_definitions,
prompt_examples=span_examples,
n_token_estimator=make_n_token_estimator(),
shard_mapper=make_shard_mapper(),
shard_reducer=make_shard_reducer(),
normalizer=normalizer,
Expand All @@ -142,7 +140,6 @@ def make_ner_task_v3(
template: str = DEFAULT_NER_TEMPLATE_V3,
label_definitions: Optional[Dict[str, str]] = None,
examples: ExamplesConfigType = None,
n_token_estimator: Optional[NTokenEstimator] = None,
shard_mapper: Optional[ShardMapper] = None,
shard_reducer: Optional[ShardReducer] = None,
normalizer: Optional[Callable[[str], str]] = None,
Expand All @@ -166,7 +163,6 @@ def make_ner_task_v3(
full examples, although both can be provided.
examples (ExamplesConfigType): Optional callable that reads a file containing task examples for
few-shot learning. If None is passed, then zero-shot learning will be used.
n_token_estimator (Optional[NTokenEstimator]): Estimates number of tokens in a string.
shard_mapper (Optional[ShardMapper]): Maps docs to shards if they don't fit into the model context.
shard_reducer (Optional[ShardReducer]): Reduces doc shards back into one doc instance.
normalizer (Optional[Callable[[str], str]]): optional normalizer function.
Expand All @@ -188,7 +184,6 @@ def make_ner_task_v3(
template=template,
label_definitions=label_definitions,
prompt_examples=span_examples,
n_token_estimator=n_token_estimator or make_n_token_estimator(),
shard_mapper=shard_mapper or make_shard_mapper(),
shard_reducer=shard_reducer or make_shard_reducer(),
normalizer=normalizer,
Expand Down
6 changes: 1 addition & 5 deletions spacy_llm/tasks/ner/task.py
Original file line number Diff line number Diff line change
Expand Up @@ -6,8 +6,7 @@
from spacy.util import filter_spans

from ...compat import Literal, Self
from ...ty import FewshotExample, NTokenEstimator, Scorer, ShardMapper, ShardReducer
from ...ty import TaskResponseParser
from ...ty import FewshotExample, Scorer, ShardMapper, ShardReducer, TaskResponseParser
from ..span import SpanTask
from ..span.task import SpanTaskLabelCheck
from ..templates import read_template
Expand All @@ -26,7 +25,6 @@ def __init__(
prompt_example_type: Type[FewshotExample[Self]],
label_definitions: Optional[Dict[str, str]],
prompt_examples: Optional[List[FewshotExample[Self]]],
n_token_estimator: NTokenEstimator,
shard_mapper: ShardMapper,
shard_reducer: ShardReducer,
normalizer: Optional[Callable[[str], str]],
Expand All @@ -44,7 +42,6 @@ def __init__(
template (str): Prompt template passed to the model.
parse_responses (TaskResponseParser[SpanTask]): Callable for parsing LLM responses for this task.
prompt_example_type (Type[FewshotExample[Self]): Type to use for fewshot examples.
n_token_estimator (NTokenEstimator): Estimates number of tokens in a string.
shard_mapper (ShardMapper): Maps docs to shards if they don't fit into the model context.
shard_reducer (ShardReducer): Reduces doc shards back into one doc instance.
label_definitions (Optional[Dict[str, str]]): Map of label -> description
Expand All @@ -66,7 +63,6 @@ def __init__(
template=template,
parse_responses=parse_responses,
prompt_example_type=prompt_example_type,
n_token_estimator=n_token_estimator,
shard_mapper=shard_mapper,
shard_reducer=shard_reducer,
label_definitions=label_definitions,
Expand Down
9 changes: 3 additions & 6 deletions spacy_llm/tasks/rel/registry.py
Original file line number Diff line number Diff line change
@@ -1,10 +1,10 @@
from typing import Callable, Dict, List, Optional, Type, Union

from ...registry import registry
from ...ty import ExamplesConfigType, FewshotExample, NTokenEstimator, ShardMapper
from ...ty import ShardReducer, TaskResponseParser
from ...ty import ExamplesConfigType, FewshotExample, ShardMapper, ShardReducer
from ...ty import TaskResponseParser
from ...util import split_labels
from ..util.sharding import make_n_token_estimator, make_shard_mapper
from ..util.sharding import make_shard_mapper
from .examples import RELExample
from .parser import parse_responses_v1
from .task import DEFAULT_REL_TEMPLATE, RELTask
Expand All @@ -24,7 +24,6 @@ def make_rel_task(
prompt_example_type: Optional[Type[FewshotExample]] = None,
label_definitions: Optional[Dict[str, str]] = None,
examples: ExamplesConfigType = None,
n_token_estimator: Optional[NTokenEstimator] = None,
shard_mapper: Optional[ShardMapper] = None,
shard_reducer: Optional[ShardReducer] = None,
normalizer: Optional[Callable[[str], str]] = None,
Expand All @@ -46,7 +45,6 @@ def make_rel_task(
full examples, although both can be provided.
examples (ExamplesConfigType): Optional callable that reads a file containing task examples for
few-shot learning. If None is passed, then zero-shot learning will be used.
n_token_estimator (Optional[NTokenEstimator]): Estimates number of tokens in a string.
shard_mapper (Optional[ShardMapper]): Maps docs to shards if they don't fit into the model context.
shard_reducer (Optional[ShardReducer]): Reduces doc shards back into one doc instance.
normalizer (Optional[Callable[[str], str]]): Optional normalizer function.
Expand All @@ -64,7 +62,6 @@ def make_rel_task(
template=template,
label_definitions=label_definitions,
prompt_examples=rel_examples,
n_token_estimator=n_token_estimator or make_n_token_estimator(),
shard_mapper=shard_mapper or make_shard_mapper(),
shard_reducer=shard_reducer or make_shard_reducer(),
normalizer=normalizer,
Expand Down
34 changes: 15 additions & 19 deletions spacy_llm/tasks/rel/task.py
Original file line number Diff line number Diff line change
Expand Up @@ -5,8 +5,7 @@
from spacy.training import Example

from ...compat import Self
from ...ty import FewshotExample, NTokenEstimator, ShardMapper, ShardReducer
from ...ty import TaskResponseParser
from ...ty import FewshotExample, ShardMapper, ShardReducer, TaskResponseParser
from ..builtin_task import BuiltinTaskWithLabels
from ..templates import read_template
from .util import EntityItem, RelationItem
Expand All @@ -23,12 +22,22 @@ def __init__(
template: str,
label_definitions: Optional[Dict[str, str]],
prompt_examples: Optional[List[FewshotExample[Self]]],
n_token_estimator: NTokenEstimator,
shard_mapper: ShardMapper,
shard_reducer: ShardReducer,
normalizer: Optional[Callable[[str], str]],
verbose: bool,
):
super().__init__(
parse_responses=parse_responses,
prompt_example_type=prompt_example_type,
template=template,
prompt_examples=prompt_examples,
shard_mapper=shard_mapper,
shard_reducer=shard_reducer,
labels=labels,
label_definitions=label_definitions,
normalizer=normalizer,
)
"""Default REL task. Populates a `Doc._.rel` custom attribute.
parse_responses (TaskResponseParser[Self]): Callable for parsing LLM responses for this task.
Expand All @@ -40,33 +49,20 @@ def __init__(
of the label to help the language model output the entities wanted.
It is usually easier to provide these definitions rather than
full examples, although both can be provided.
prompt_examples (Optional[List[FewshotExample[Self]]]): Optional list of few-shot examples to include in prompts.
n_token_estimator (NTokenEstimator): Estimates number of tokens in a string.
prompt_examples (Optional[List[FewshotExample[Self]]]): Optional list of few-shot examples to include in
prompts.
shard_mapper (ShardMapper): Maps docs to shards if they don't fit into the model context.
shard_reducer (ShardReducer): Reduces doc shards back into one doc instance.
normalizer (Optional[Callable[[str], str]]): Optional normalizer function.
verbose (bool): Controls the verbosity of the task.
"""
super().__init__(
parse_responses=parse_responses,
prompt_example_type=prompt_example_type,
template=template,
prompt_examples=prompt_examples,
n_token_estimator=n_token_estimator,
shard_mapper=shard_mapper,
shard_reducer=shard_reducer,
labels=labels,
label_definitions=label_definitions,
normalizer=normalizer,
)
self._verbose = verbose
self._field = "rel"

def _preprocess_docs_for_prompt(self, docs: Iterable[Doc]) -> Iterable[Doc]:
return [Doc(doc.vocab, words=RELTask._preannotate(doc).split()) for doc in docs]

@property
def _prompt_data(self) -> Dict[str, Any]:
def _get_prompt_data(self, shard: Doc) -> Dict[str, Any]:
return {
"labels": list(self._label_dict.values()),
"label_definitions": self._label_definitions,
Expand Down
Loading

0 comments on commit e47f762

Please sign in to comment.