diff --git a/.github/scripts/build_sdist_and_wheel.sh b/.github/scripts/build_sdist_and_wheel.sh new file mode 100755 index 000000000..ca770f5b7 --- /dev/null +++ b/.github/scripts/build_sdist_and_wheel.sh @@ -0,0 +1,22 @@ +#!/bin/bash + +# Build sdist and wheel +python -m pip install -U pip +python -m pip install build +python -m build + +# Check sdist install and imports +mkdir -p test-sdist +cd test-sdist +python -m venv venv-sdist +venv-sdist/bin/python -m pip install ../dist/outlines-*.tar.gz +venv-sdist/bin/python -c "import outlines" +cd .. + +# Check wheel install and imports +mkdir -p test-wheel +cd test-wheel +python -m venv venv-wheel +venv-wheel/bin/python -m pip install ../dist/outlines-*.whl +venv-wheel/bin/python -c "import outlines" +cd .. diff --git a/.github/workflows/asv_benchmark_pr.yml b/.github/workflows/asv_benchmark_pr.yml new file mode 100644 index 000000000..90fb47423 --- /dev/null +++ b/.github/workflows/asv_benchmark_pr.yml @@ -0,0 +1,57 @@ +name: Benchmark PR + +on: + pull_request: + branches: [main] + workflow_dispatch: +env: + PYTHON_VERSION: "3.10" + WORKING_DIR: ${{ github.workspace }}/benchmarks + BENCHMARKS_OUTPUT: ${{ github.workspace }}/benchmarks_output + +jobs: + benchmark-pr: + runs-on: ubuntu-latest + if: contains(github.event.pull_request.labels.*.name, 'run_benchmarks') || github.event_name == 'workflow_dispatch' || github.event_name == 'workflow_run' + + defaults: + run: + working-directory: ${{ env.WORKING_DIR }} + + steps: + + - name: Checkout repository + uses: actions/checkout@v3 + with: + fetch-depth: 0 + + - name: Set up Python + uses: actions/setup-python@v4 + with: + python-version: ${{ env.PYTHON_VERSION }} + + - name: Install dependencies + run: | + python -m pip install --upgrade pip + pip install asv virtualenv lf-asv-formatter + + - name: Create ASV machine config file + run: asv machine --machine gh-runner --yes + + - name: Run Benchmarks - `PR HEAD` vs `main` + run: | + # prepare main branch for comparison + git remote add upstream https://github.com/${{ github.repository }}.git + git fetch upstream main + + # Run benchmarks, allow errors, they will be caught in the next step + asv continuous upstream/main HEAD \ + --no-stats --interleave-rounds -a repeat=3 || true + + - name: BENCHMARK RESULTS + run: | + asv compare --factor=1.1 --no-stats --split upstream/main HEAD | tee ${{ env.BENCHMARKS_OUTPUT }} + if grep -q "Benchmarks that have got worse" "${{ env.BENCHMARKS_OUTPUT }}"; then + echo "Performance degradation detected!" + exit 1 + fi diff --git a/.github/workflows/release_pypi.yaml b/.github/workflows/release_pypi.yaml index 0006e74f2..9f78cfc43 100644 --- a/.github/workflows/release_pypi.yaml +++ b/.github/workflows/release_pypi.yaml @@ -15,28 +15,11 @@ jobs: uses: actions/setup-python@v2 with: python-version: "3.10" - - name: Build sdist and wheel - run: | - python -m pip install -U pip - python -m pip install build - python -m build + - name: Build SDist and Wheel + run: ./.github/scripts/build_sdist_and_wheel.sh - name: Check that the package version matches the Release name run: | grep -Rq "^Version: ${GITHUB_REF:10}$" outlines.egg-info/PKG-INFO - - name: Check sdist install and imports - run: | - mkdir -p test-sdist - cd test-sdist - python -m venv venv-sdist - venv-sdist/bin/python -m pip install ../dist/outlines-*.tar.gz - venv-sdist/bin/python -c "import outlines" - - name: Check wheel install and imports - run: | - mkdir -p test-wheel - cd test-wheel - python -m venv venv-wheel - venv-wheel/bin/python -m pip install ../dist/outlines-*.whl - venv-wheel/bin/python -c "import outlines" - name: Publish to PyPi uses: pypa/gh-action-pypi-publish@v1.4.2 with: diff --git a/.github/workflows/tests.yml b/.github/workflows/tests.yml index b8d4208af..10879c78f 100644 --- a/.github/workflows/tests.yml +++ b/.github/workflows/tests.yml @@ -88,3 +88,11 @@ jobs: name: html-report path: htmlcov if: ${{ failure() }} + + build-wheel: + name: Build Wheel and Test SDist + runs-on: ubuntu-latest + steps: + - uses: actions/checkout@v3 + - name: Build SDist and Wheel + run: ./.github/scripts/build_sdist_and_wheel.sh diff --git a/.gitignore b/.gitignore index 9e95a8732..9add6d8c4 100644 --- a/.gitignore +++ b/.gitignore @@ -6,3 +6,4 @@ docs/build .idea/ *.gguf .venv +benchmarks/results diff --git a/.pre-commit-config.yaml b/.pre-commit-config.yaml index 8fe83f89a..b528f0e8e 100644 --- a/.pre-commit-config.yaml +++ b/.pre-commit-config.yaml @@ -30,3 +30,4 @@ repos: - id: mypy args: [--allow-redefinition] exclude: ^examples/ + additional_dependencies: [types-tqdm] diff --git a/README.md b/README.md index 4c3443639..e6356e2f3 100644 --- a/README.md +++ b/README.md @@ -40,7 +40,7 @@ First time here? Go to our [setup guide](https://outlines-dev.github.io/outlines - [x] 🚀 [Serve with vLLM](https://outlines-dev.github.io/outlines/reference/vllm), with official Docker image, [`outlinesdev/outlines`](https://hub.docker.com/r/outlinesdev/outlines)! -Outlines 〰 has new releases and features coming every week. Make sure to ⭐ star and 👀 watch this repository, follow [@dottxtai][twitter] to stay up to date! +Outlines 〰 has new releases and features coming every week. Make sure to ⭐ star and 👀 watch this repository, follow [@dottxtai][dottxt-twitter] to stay up to date! ## Why should I use structured generation? diff --git a/benchmarks/__init__.py b/benchmarks/__init__.py new file mode 100644 index 000000000..e69de29bb diff --git a/benchmarks/asv.conf.json b/benchmarks/asv.conf.json new file mode 100644 index 000000000..f57db9a0b --- /dev/null +++ b/benchmarks/asv.conf.json @@ -0,0 +1,20 @@ +{ + "version": 1, + "project": "Outlines", + "project_url": "https://outlines-dev.github.io/outlines/", + "repo": "..", + "branches": [ + "HEAD" + ], + "build_command": [ + "python -mpip install .[test]", + "PIP_NO_BUILD_ISOLATION=false python -mpip wheel --no-deps --no-index -w {build_cache_dir} {build_dir}", + ], + "environment_type": "virtualenv", + "show_commit_url": "https://github.com/outlines-dev/outlines/commit/", + "benchmark_dir": ".", + "env_dir": "env", + "results_dir": "results", + "html_dir": "html", + "build_cache_size": 8 +} diff --git a/tests/benchmark/test_benchmark_json_schema.py b/benchmarks/bench_json_schema.py similarity index 63% rename from tests/benchmark/test_benchmark_json_schema.py rename to benchmarks/bench_json_schema.py index 33f3f5b16..8d1ceeb24 100644 --- a/tests/benchmark/test_benchmark_json_schema.py +++ b/benchmarks/bench_json_schema.py @@ -1,11 +1,8 @@ -import pytest +from outlines.caching import cache_disabled +from outlines.fsm.guide import RegexGuide +from outlines.fsm.json_schema import build_regex_from_schema -import outlines - -outlines.disable_cache() - -from outlines.fsm.guide import RegexGuide # noqa: E402 -from outlines.fsm.json_schema import build_regex_from_schema # noqa: E402 +from .common import ensure_numba_compiled, setup_tokenizer # noqa: E402 simple_schema = """{ "$defs": { @@ -63,30 +60,22 @@ "required": ["id", "work", "recording_artists"] }""" - schemas = dict(simple_schema=simple_schema, complex_schema=complex_schema) -@pytest.mark.parametrize("schema_name", schemas.keys()) -def test_benchmark_json_schema_to_regex(benchmark, ensure_numba_compiled, schema_name): - """Benchmark convert json schema to regex""" - schema = schemas[schema_name] - benchmark.pedantic( - build_regex_from_schema, - args=(schema,), - rounds=8, - ) +class JsonSchemaBenchmark: + params = schemas.keys() + + def setup(self, schema_name): + self.tokenizer = setup_tokenizer() + self.schema = schemas[schema_name] + ensure_numba_compiled(self.tokenizer) + @cache_disabled() + def time_json_schema_to_regex(self, schema_name): + build_regex_from_schema(self.schema) -@pytest.mark.parametrize("schema_name", schemas.keys()) -def test_benchmark_json_schema_to_fsm( - benchmark, tokenizer, ensure_numba_compiled, schema_name -): - """Benchmark compile json schema as FSM""" - schema = schemas[schema_name] - regex = build_regex_from_schema(schema) - benchmark.pedantic( - RegexGuide, - args=(regex, tokenizer), - rounds=8, - ) + @cache_disabled() + def time_json_schema_to_fsm(self, schema_name): + regex = build_regex_from_schema(self.schema) + RegexGuide(regex, self.tokenizer) diff --git a/benchmarks/bench_numba_compile.py b/benchmarks/bench_numba_compile.py new file mode 100644 index 000000000..2713707e5 --- /dev/null +++ b/benchmarks/bench_numba_compile.py @@ -0,0 +1,34 @@ +import importlib + +import interegular +import numba + +from outlines.caching import cache_disabled +from outlines.fsm import regex + +from .common import setup_tokenizer + + +class NumbaCompileBenchmark: + def setup(self): + self.tokenizer = setup_tokenizer() + self.regex = regex + original_njit = numba.njit + + def mock_njit(*args, **kwargs): + kwargs["cache"] = False + return original_njit(*args, **kwargs) + + self.original_njit = original_njit + numba.njit = mock_njit + importlib.reload(self.regex) + self.regex_pattern, _ = self.regex.make_deterministic_fsm( + interegular.parse_pattern("a").to_fsm().reduce() + ) + + def teardown(self): + numba.njit = self.original_njit + + @cache_disabled() + def time_compile_numba(self): + self.regex.create_fsm_index_tokenizer(self.regex_pattern, self.tokenizer) diff --git a/tests/benchmark/test_benchmark_regex_fsm.py b/benchmarks/bench_regex_guide.py similarity index 66% rename from tests/benchmark/test_benchmark_regex_fsm.py rename to benchmarks/bench_regex_guide.py index e9e45052a..099f94df2 100644 --- a/tests/benchmark/test_benchmark_regex_fsm.py +++ b/benchmarks/bench_regex_guide.py @@ -1,10 +1,7 @@ -import pytest +from outlines.caching import cache_disabled +from outlines.fsm.guide import RegexGuide -import outlines - -outlines.disable_cache() - -from outlines.fsm.guide import RegexGuide # noqa: E402 +from .common import ensure_numba_compiled, setup_tokenizer regex_samples = { "email": r"[a-z0-9!#$%&'*+/=?^_`{|}~-]+(?:\.[a-z0-9!#$%&'*+/=?^_`{|}~-]+)*@(?:[a-z0-9](?:[a-z0-9-]*[a-z0-9])?\.)+[a-z0-9](?:[a-z0-9-]*[a-z0-9])?", @@ -19,14 +16,27 @@ } -@pytest.mark.parametrize("regex_name", regex_samples.keys()) -def test_benchmark_regex_to_fsm( - benchmark, tokenizer, ensure_numba_compiled, regex_name -): - """Benchmark converting regex to FSM""" - regex_str = regex_samples[regex_name] - benchmark.pedantic( - RegexGuide, - args=(regex_str, tokenizer), - rounds=8, - ) +class RegexGuideBenchmark: + params = regex_samples.keys() + + def setup(self, pattern_name): + self.tokenizer = setup_tokenizer() + ensure_numba_compiled(self.tokenizer) + self.pattern = regex_samples[pattern_name] + + @cache_disabled() + def time_regex_to_guide(self, pattern_name): + RegexGuide(self.pattern, self.tokenizer) + + +class MemoryRegexGuideBenchmark: + params = ["simple_phone", "complex_span_constrained_relation_extraction"] + + def setup(self, pattern_name): + self.tokenizer = setup_tokenizer() + ensure_numba_compiled(self.tokenizer) + self.pattern = regex_samples[pattern_name] + + @cache_disabled() + def peakmem_regex_to_guide(self, pattern_name): + RegexGuide(self.pattern, self.tokenizer) diff --git a/tests/benchmark/conftest.py b/benchmarks/common.py similarity index 83% rename from tests/benchmark/conftest.py rename to benchmarks/common.py index 902d5d6eb..7d999ea9b 100644 --- a/tests/benchmark/conftest.py +++ b/benchmarks/common.py @@ -1,17 +1,14 @@ -import pytest from transformers import AutoTokenizer from outlines.fsm.guide import RegexGuide from outlines.models.transformers import TransformerTokenizer -@pytest.fixture -def tokenizer(): +def setup_tokenizer(): tokenizer = AutoTokenizer.from_pretrained("gpt2") return TransformerTokenizer(tokenizer) -@pytest.fixture def ensure_numba_compiled(tokenizer): RegexGuide("a", tokenizer) return True diff --git a/docs/community/contribute.md b/docs/community/contribute.md index 1df15084a..b336eacad 100644 --- a/docs/community/contribute.md +++ b/docs/community/contribute.md @@ -39,7 +39,7 @@ source .venv/bin/activate Then install the dependencies in editable mode, and install the pre-commit hooks: ```python -pip install -e .[test] +pip install -e ".[test]" pre-commit install ``` @@ -57,12 +57,38 @@ And run the code style checks: pre-commit run --all-files ``` -When modifying the code related to the index compilation, we kindly ask you to -post benchmarks before and after your changes. You can run benchmarks using: +### Benchmarking -```python -pytest --benchmark-only +Outlines uses [asv](https://asv.readthedocs.io) for automated benchmark testing. Benchmarks are run automatically before pull requests are merged to prevent performance degredation. + +You can run the benchmark test suite locally with the following command: +``` +asv run --config benchmarks/asv.conf.json +``` + +Run a specific test: +``` +asv run --config benchmarks/asv.conf.json -b bench_json_schema.JsonSchemaBenchmark.time_json_schema_to_fsm +``` + +Profile a specific test: ``` +asv run --config benchmarks/asv.conf.json --profile -b bench_json_schema.JsonSchemaBenchmark.time_json_schema_to_fsm +``` + +Compare to `origin/main` +``` +get fetch origin +asv continuous origin/main HEAD --config benchmarks/asv.conf.json +``` + +#### ASV PR Behavior + +- **View ASV Benchmark Results:** Open the workflow, view `BENCHMARK RESULTS` section. +- Merging is blocked unless benchmarks are run for the latest commit. +- Benchmarks fail if performance degrades by more than 10% for any individual benchmark. +- The "Benchmark PR" workflow runs when its manually dispatched, or if the `run_benchmarks` label is added to the PR they run for every commit. + ### Contribute to the documentation diff --git a/docs/community/versioning.md b/docs/community/versioning.md new file mode 100644 index 000000000..d64a56e7f --- /dev/null +++ b/docs/community/versioning.md @@ -0,0 +1,26 @@ +--- +title: Versioning Guide +--- + +# Versioning Guide + + +The Outlines project follows a structured versioning scheme designed to provide clarity and minimize risk for downstream dependents. + +Each part of the version number (`major.minor.patch`) conveys information about the nature and impact of the changes included in the release. + +- **Major Releases** includes compatibility-breaking changes to core interfaces, such as `LogitsProcessor`s and `Guides`. +- **Minor Releases** introduce changes of substance to internal or unexposed functionality. These changes are well tested and intended to maintain compatability with existing use of core interfaces. +- **Patch Releases** address bug fixes and incorporate low-risk changes to improve stability and performance. + +## Releases + +Releases along with release notes can be found on the [Outlines Releases GitHub Page](https://github.com/outlines-dev/outlines/releases). + +## Version Pinning Recommendations + +Here are our recommendations for managing dependencies on the Outlines package: + +**Small, Risk-Tolerant Projects:** Pin to a specific major version. + +**Large, Conservative Projects:** Pin to a specific minor version. diff --git a/docs/cookbook/deploy-using-bentoml.md b/docs/cookbook/deploy-using-bentoml.md index f3770d07d..6bee77441 100644 --- a/docs/cookbook/deploy-using-bentoml.md +++ b/docs/cookbook/deploy-using-bentoml.md @@ -50,13 +50,14 @@ $ bentoml models list Tag Module Size Creation Time mistralai--mistral-7b-v0.1:m7lmf5ac2cmubnnz 13.49 GiB 2024-04-25 06:52:39 +``` + ## Define a BentoML Service As the model is ready, we can define a [BentoML Service](https://docs.bentoml.com/en/latest/guides/services.html) to wrap the capabilities of the model. We will run the JSON-structured generation example [in the README](https://github.com/outlines-dev/outlines?tab=readme-ov-file#efficient-json-generation-following-a-json-schema), with the following schema: - ```python DEFAULT_SCHEMA = """{ "title": "Character", @@ -206,6 +207,8 @@ Expected output: "weapon": "sword", "strength": 20 } +``` + ## Deploy to BentoCloud After the Service is ready, you can deploy it to [BentoCloud](https://docs.bentoml.com/en/latest/bentocloud/get-started.html) for better management and scalability. [Sign up](https://cloud.bentoml.com/signup) if you haven't got a BentoCloud account. diff --git a/docs/reference/json.md b/docs/reference/json.md index 3b5976f19..85e1a846a 100644 --- a/docs/reference/json.md +++ b/docs/reference/json.md @@ -36,10 +36,10 @@ print(result) !!! Note "JSON and whitespaces" - By default Outlines lets model choose the number of linebreaks and white spaces used to structure the JSON. Small models tend to struggle with this, in which case we recommend to set the value of the parameter `whitespace_pattern` to the empty string: + By default Outlines prevents the model from generating json with syntactic newlines, tabs, or multiple spaces. The default `whitespace_pattern` is `r"[ ]?"`. Small models tend to enter an infinite repetition loop if the `whitespace_pattern` allows infinite spacing. If you would like to allow the model to generate multiple tabs, newlines, and spaces, you can set the whitespace pattern as follows: ```python - generator = generate.json(model, User, whitespace_pattern="") + generator = generate.json(model, User, whitespace_pattern=r"[\n\t ]*") ``` !!! Note "Performance" diff --git a/docs/reference/models/openai.md b/docs/reference/models/openai.md index 5ddd4a457..07357a360 100644 --- a/docs/reference/models/openai.md +++ b/docs/reference/models/openai.md @@ -19,7 +19,7 @@ print(type(model)) Outlines also supports Azure OpenAI models: -``` +```python from outlines import models model = models.azure_openai( @@ -30,7 +30,7 @@ model = models.azure_openai( More generally, you can use any API client compatible with the OpenAI interface by passing an instance of the client, a configuration, and optionally the corresponding tokenizer (if you want to be able to use `outlines.generate.choice`): -``` +```python from openai import AsyncOpenAI import tiktoken diff --git a/docs/reference/models/vllm.md b/docs/reference/models/vllm.md index 7fc29f00c..b5221e582 100644 --- a/docs/reference/models/vllm.md +++ b/docs/reference/models/vllm.md @@ -123,7 +123,7 @@ from outlines import models, generate model = models.vllm("mistralai/Mistral-7b-v0.1") generator = generate.text(model) -params = SamplingParams(n=2, frequence_penalty=1., min_tokens=2) +params = SamplingParams(n=2, frequency_penalty=1., min_tokens=2) answer = generator("A prompt", sampling_params=params) ``` diff --git a/docs/reference/prompting.md b/docs/reference/prompting.md index a7731ba0f..34860fce0 100644 --- a/docs/reference/prompting.md +++ b/docs/reference/prompting.md @@ -223,7 +223,7 @@ Several projects (e.g.[Toolformer](https://arxiv.org/abs/2302.04761), [ViperGPT] Can you do something? COMMANDS - 1. my_tool: Tool description, args: arg1:str, arg2:int + 1. my_tool: Tool description., args: arg1: str, arg2: int def my_tool(arg1: str, arg2: int): """Tool description. diff --git a/docs/reference/types.md b/docs/reference/types.md index 645249263..2e02f45f1 100644 --- a/docs/reference/types.md +++ b/docs/reference/types.md @@ -2,11 +2,8 @@ Outlines provides custom Pydantic types so you can focus on your use case rather than on writing regular expressions: - | Category | Type | Import | Description | |:--------:|:----:|:-------|:------------| -| Zip code | US | `outlines.types.ZipCode` | Generate US Zip(+4) codes | -| Phone number | US | `outlines.types.PhoneNumber` | Generate valid US phone numbers | | ISBN | 10 & 13 | `outlines.types.ISBN` | There is no guarantee that the [check digit][wiki-isbn] will be correct | | Airport | IATA | `outlines.types.airports.IATA` | Valid [airport IATA codes][wiki-airport-iata] | | Country | alpha-2 code | `outlines.types.airports.Alpha2` | Valid [country alpha-2 codes][wiki-country-alpha-2] | @@ -14,6 +11,24 @@ Outlines provides custom Pydantic types so you can focus on your use case rather | | numeric code | `outlines.types.countries.Numeric` | Valid [country numeric codes][wiki-country-numeric] | | | name | `outlines.types.countries.Name` | Valid country names | | | flag | `outlines.types.countries.Flag` | Valid flag emojis | +| | email | `outlines.types.Email` | Valid email address | + +Some types require localization. We currently only support US types, but please don't hesitate to create localized versions of the different types and open a Pull Request. Localized types are specified using `types.locale` in the following way: + +```python +from outlines import types + +types.locale("us").ZipCode +types.locale("us").PhoneNumber +``` + +Here are the localized types that are currently available: + +| Category | Locale | Import | Description | +|:--------:|:----:|:-------|:------------| +| Zip code | US | `ZipCode` | Generate US Zip(+4) codes | +| Phone number | US | `PhoneNumber` | Generate valid US phone numbers | + You can use these types in Pydantic schemas for JSON-structured generation: @@ -22,11 +37,13 @@ from pydantic import BaseModel from outlines import models, generate, types +# Specify the locale for types +locale = types.locale("us") class Client(BaseModel): name: str - phone_number: types.PhoneNumber - zip_code: types.ZipCode + phone_number: locale.PhoneNumber + zip_code: locale.ZipCode model = models.transformers("mistralai/Mistral-7B-v0.1") @@ -47,7 +64,7 @@ from outlines import models, generate, types model = models.transformers("mistralai/Mistral-7B-v0.1") -generator = generate.format(model, types.PhoneNumber) +generator = generate.format(model, types.locale("us").PhoneNumber) result = generator( "Return a US Phone number: " ) diff --git a/mkdocs.yml b/mkdocs.yml index c4a9edcbc..01e8506ab 100644 --- a/mkdocs.yml +++ b/mkdocs.yml @@ -143,4 +143,5 @@ nav: - Chat with us ☕: https://discord.com/invite/R9DSu34mGd - How to contribute 🏗️: community/contribute.md - Your projects 👏: community/examples.md + - Versioning Guide 📌: community/versioning.md - Blog: blog/index.md diff --git a/outlines/caching.py b/outlines/caching.py index 68207a0e4..95392c7e8 100644 --- a/outlines/caching.py +++ b/outlines/caching.py @@ -1,15 +1,41 @@ import asyncio +import contextlib import functools -import hashlib import os from typing import Callable, Optional import cloudpickle -from diskcache import Cache +from diskcache import Cache, Disk +from diskcache.core import ENOVAL, UNKNOWN, args_to_key, full_name _caching_enabled = True +class CloudpickleDisk(Disk): + def __init__(self, directory, compress_level=1, **kwargs): + self.compress_level = compress_level + super().__init__(directory, **kwargs) + + def put(self, key): + data = cloudpickle.dumps(key) + return super().put(data) + + def get(self, key, raw): + data = super().get(key, raw) + return cloudpickle.loads(data) + + def store(self, value, read, key=UNKNOWN): + if not read: + value = cloudpickle.dumps(value) + return super().store(value, read, key=key) + + def fetch(self, mode, filename, value, read): + data = super().fetch(mode, filename, value, read) + if not read: + data = cloudpickle.loads(data) + return data + + @functools.lru_cache(1) def get_cache(): """Get the context object that contains previously-computed return values. @@ -26,7 +52,12 @@ def get_cache(): home_dir = os.path.expanduser("~") cache_dir = os.environ.get("OUTLINES_CACHE_DIR", f"{home_dir}/.cache/outlines") - memory = Cache(cache_dir, eviction_policy="none", cull_limit=0) + memory = Cache( + cache_dir, + eviction_policy="none", + cull_limit=0, + disk=CloudpickleDisk, + ) # ensure if version upgrade occurs, old cache is pruned if outlines_version != memory.get("__version__"): @@ -36,63 +67,72 @@ def get_cache(): return memory -def hash_arguments(*args, **kwargs) -> str: - """Create a hash out of the args and kwargs provided""" - result = hashlib.md5() - for item in list(args) + sorted(kwargs.items()): - result.update(cloudpickle.dumps(item)) - return result.hexdigest() - - -def cache(key_function: Optional[Callable] = None): +def cache(expire: Optional[float] = None, typed=False, ignore=()): """Caching decorator for memoizing function calls. + The cache key is created based on the values returned by the key_function callable if provided or based on the arguments of the decorated function directly otherwise + + This is based on `diskcache`'s `memoize`. + Parameters ---------- - key_function - A callable function used to generate a unique key for each function call. It's - called with the arguments of the decorated function as arguments + expire + Seconds until arguments expire. + typed + Cache different types separately. + ignore + Positional or keyword arguments to ignore. + Returns ------- - A decorator function that can be applied to other functions. + A decorator function that can be applied to other functions. """ def decorator(cached_function: Callable): memory = get_cache() - def wrapper(*args, **kwargs): - if not _caching_enabled: - return cached_function(*args, **kwargs) - if key_function: - key_args = key_function(*args, **kwargs) - cache_key = hash_arguments(*key_args) - else: - cache_key = hash_arguments(*args, **kwargs) - if cache_key in memory: - return memory[cache_key] - result = cached_function(*args, **kwargs) - memory[cache_key] = result - return result - - async def async_wrapper(*args, **kwargs): - if not _caching_enabled: - return await cached_function(*args, **kwargs) - if key_function: - key_args = key_function(*args, **kwargs) - cache_key = hash_arguments(*key_args) - else: - cache_key = hash_arguments(*args, **kwargs) - if cache_key in memory: - return memory[cache_key] - result = await cached_function(*args, **kwargs) - memory[cache_key] = result - return result + base = (full_name(cached_function),) if asyncio.iscoroutinefunction(cached_function): - return async_wrapper + + async def wrapper(*args, **kwargs): + if not _caching_enabled: + return await cached_function(*args, **kwargs) + + cache_key = wrapper.__cache_key__(*args, **kwargs) + result = wrapper.__memory__.get(cache_key, default=ENOVAL, retry=True) + + if result is ENOVAL: + result = await cached_function(*args, **kwargs) + wrapper.__memory__.set(cache_key, result, expire, retry=True) + + return result + else: - return wrapper + + def wrapper(*args, **kwargs): + if not _caching_enabled: + return cached_function(*args, **kwargs) + + cache_key = wrapper.__cache_key__(*args, **kwargs) + result = wrapper.__memory__.get(cache_key, default=ENOVAL, retry=True) + + if result is ENOVAL: + result = cached_function(*args, **kwargs) + wrapper.__memory__.set(cache_key, result, expire, retry=True) + + return result + + def __cache_key__(*args, **kwargs): + """Make key for cache given function arguments.""" + return args_to_key(base, args, kwargs, typed, ignore) + + wrapper.__cache_key__ = __cache_key__ # type: ignore + wrapper.__memory__ = memory # type: ignore + wrapper.__wrapped__ = cached_function # type: ignore + + return wrapper return decorator @@ -125,3 +165,15 @@ def clear_cache(): """Erase the cache completely.""" memory = get_cache() memory.clear() + + +@contextlib.contextmanager +def cache_disabled(): + # outlines.caching._caching_enabled + global _caching_enabled + original_state = _caching_enabled + _caching_enabled = False + try: + yield + finally: + _caching_enabled = original_state diff --git a/outlines/fsm/guide.py b/outlines/fsm/guide.py index 5c7b56326..d247db62b 100644 --- a/outlines/fsm/guide.py +++ b/outlines/fsm/guide.py @@ -105,44 +105,44 @@ def copy(self): return self +@cache() +def create_states_mapping( + regex_string: str, tokenizer: "Tokenizer" +) -> Tuple[dict, set, set]: + """Create the variables related to the mapping between states and tokens + The parameters of the function are used for caching purpose + """ + regex_pattern = interegular.parse_pattern(regex_string) + byte_fsm = make_byte_level_fsm(regex_pattern.to_fsm().reduce(), keep_utf8=True) + regex_fsm, _ = make_deterministic_fsm(byte_fsm) + states_to_token_maps, empty_token_ids = create_fsm_index_tokenizer( + regex_fsm, tokenizer + ) + + # We make sure that it is possible to generate strings in the language + # of the regular expression with the tokens present in the model's + # vocabulary. + if not any( + regex_fsm.finals.intersection(v.values()) for v in states_to_token_maps.values() + ): + raise ValueError( + "The vocabulary does not allow us to build a sequence that matches the input regex" + ) + + return states_to_token_maps, empty_token_ids, regex_fsm.finals + + class RegexGuide(Guide): """Guide to generate text in the language of a regular expression.""" initial_state = 0 def __init__(self, regex_string: str, tokenizer): - @cache() - def create_states_mapping(regex_string: str) -> Tuple[dict, set, set]: - """Create the variables related to the mapping between states and tokens - The parameters of the function are used for caching purpose - """ - regex_pattern = interegular.parse_pattern(regex_string) - byte_fsm = make_byte_level_fsm( - regex_pattern.to_fsm().reduce(), keep_utf8=True - ) - regex_fsm, _ = make_deterministic_fsm(byte_fsm) - states_to_token_maps, empty_token_ids = create_fsm_index_tokenizer( - regex_fsm, tokenizer - ) - - # We make sure that it is possible to generate strings in the language - # of the regular expression with the tokens present in the model's - # vocabulary. - if not any( - regex_fsm.finals.intersection(v.values()) - for v in states_to_token_maps.values() - ): - raise ValueError( - "The vocabulary does not allow us to build a sequence that matches the input regex" - ) - - return states_to_token_maps, empty_token_ids, regex_fsm.finals - ( self.states_to_token_maps, self.empty_token_ids, fsm_finals, - ) = create_states_mapping(regex_string) + ) = create_states_mapping(regex_string, tokenizer) self.eos_token_id = tokenizer.eos_token_id self.final_states = fsm_finals | {-1} @@ -193,12 +193,8 @@ def get_next_state(self, state: int, token_id: int) -> int: The new state of the guide. """ - if token_id == self.eos_token_id: + if token_id == self.eos_token_id or state not in self.states_to_token_maps: return -1 - elif ( - state in self.final_states - ): # Necessary because we keep generating EOS tokens when finished - return state last_token_to_end_state = self.states_to_token_maps[state] next_state = last_token_to_end_state.get(token_id) diff --git a/outlines/fsm/json_schema.py b/outlines/fsm/json_schema.py index 647c95a22..3bd4816a9 100644 --- a/outlines/fsm/json_schema.py +++ b/outlines/fsm/json_schema.py @@ -1,6 +1,7 @@ import inspect import json import re +import warnings from typing import Callable, Optional from jsonschema.protocols import Validator @@ -9,13 +10,13 @@ from referencing._core import Resolver from referencing.jsonschema import DRAFT202012 -STRING_INNER = r'([^("\\\x00-\x1f\x7f-\x9f)]|\\\\)' +STRING_INNER = r'([^"\\\x00-\x1f\x7f-\x9f]|\\\\)' STRING = f'"{STRING_INNER}*"' INTEGER = r"(-)?(0|[1-9][0-9]*)" NUMBER = rf"({INTEGER})(\.[0-9]+)?([eE][+-][0-9]+)?" BOOLEAN = r"(true|false)" NULL = r"null" -WHITESPACE = r"[\n ]*" +WHITESPACE = r"[ ]?" type_to_regex = { "string": STRING, @@ -125,7 +126,22 @@ def to_regex( if whitespace_pattern is None: whitespace_pattern = WHITESPACE - if "properties" in instance: + if instance == {}: + # JSON Schema Spec: Empty object means unconstrained, any json type is legal + types = [ + {"type": "boolean"}, + {"type": "null"}, + {"type": "number"}, + {"type": "integer"}, + {"type": "string"}, + {"type": "array"}, + {"type": "object"}, + ] + regexes = [to_regex(resolver, t, whitespace_pattern) for t in types] + regexes = [rf"({r})" for r in regexes] + return rf"{'|'.join(regexes)}" + + elif "properties" in instance: regex = "" regex += r"\{" properties = instance["properties"] @@ -194,16 +210,19 @@ def to_regex( to_regex(resolver, t, whitespace_pattern) for t in instance["oneOf"] ] - xor_patterns = [] - # json schema validation ensured there is no overlapping schemas in oneOf - for subregex in subregexes: - other_subregexes = filter(lambda r: r != subregex, subregexes) - other_subregexes_str = "|".join([f"{s}" for s in other_subregexes]) - negative_lookahead = f"(?!.*({other_subregexes_str}))" - xor_patterns.append(f"({subregex}){negative_lookahead}") + xor_patterns = [f"(?:{subregex})" for subregex in subregexes] return rf"({'|'.join(xor_patterns)})" + # Create pattern for Tuples, per JSON Schema spec, `prefixItems` determines types at each idx + elif "prefixItems" in instance: + element_patterns = [ + to_regex(resolver, t, whitespace_pattern) for t in instance["prefixItems"] + ] + comma_split_pattern = rf"{whitespace_pattern},{whitespace_pattern}" + tuple_inner = comma_split_pattern.join(element_patterns) + return rf"\[{whitespace_pattern}{tuple_inner}{whitespace_pattern}\]" + # The enum keyword is used to restrict a value to a fixed set of values. It # must be an array with at least one element, where each element is unique. elif "enum" in instance: @@ -293,15 +312,22 @@ def to_regex( # Here we need to make the choice to exclude generating list of objects # if the specification of the object is not given, even though a JSON # object that contains an object here would be valid under the specification. - types = [ + legal_types = [ {"type": "boolean"}, {"type": "null"}, {"type": "number"}, {"type": "integer"}, {"type": "string"}, ] - regexes = [to_regex(resolver, t, whitespace_pattern) for t in types] - return rf"\[{whitespace_pattern}({'|'.join(regexes)})(,{whitespace_pattern}({'|'.join(regexes)})){num_repeats}){allow_empty}{whitespace_pattern}\]" + depth = instance.get("depth", 2) + if depth > 0: + legal_types.append({"type": "object", "depth": depth - 1}) + legal_types.append({"type": "array", "depth": depth - 1}) + + regexes = [ + to_regex(resolver, t, whitespace_pattern) for t in legal_types + ] + return rf"\[{whitespace_pattern}({'|'.join(regexes)})(,{whitespace_pattern}({'|'.join(regexes)})){num_repeats}{allow_empty}{whitespace_pattern}\]" elif instance_type == "object": # pattern for json object with values defined by instance["additionalProperties"] @@ -317,8 +343,30 @@ def to_regex( allow_empty = "?" if int(instance.get("minProperties", 0)) == 0 else "" + additional_properties = instance.get("additionalProperties") + + if additional_properties is None or additional_properties is True: + # JSON Schema behavior: If the additionalProperties of an object is + # unset or True, it is unconstrained object. + # We handle this by setting additionalProperties to anyOf: {all types} + + legal_types = [ + {"type": "string"}, + {"type": "number"}, + {"type": "boolean"}, + {"type": "null"}, + ] + + # We set the object depth to 2 to keep the expression finite, but the "depth" + # key is not a true component of the JSON Schema specification. + depth = instance.get("depth", 2) + if depth > 0: + legal_types.append({"type": "object", "depth": depth - 1}) + legal_types.append({"type": "array", "depth": depth - 1}) + additional_properties = {"anyOf": legal_types} + value_pattern = to_regex( - resolver, instance["additionalProperties"], whitespace_pattern + resolver, additional_properties, whitespace_pattern ) key_value_pattern = ( f"{STRING}{whitespace_pattern}:{whitespace_pattern}{value_pattern}" @@ -375,6 +423,14 @@ def get_schema_from_signature(fn: Callable) -> str: else: arguments[name] = (arg.annotation, ...) - model = create_model("Arguments", **arguments) + try: + fn_name = fn.__name__ + except Exception as e: + fn_name = "Arguments" + warnings.warn( + f"The function name could not be determined. Using default name 'Arguments' instead. For debugging, here is exact error:\n{e}", + category=UserWarning, + ) + model = create_model(fn_name, **arguments) return model.model_json_schema() diff --git a/outlines/fsm/regex.py b/outlines/fsm/regex.py index 0941bbb9f..b68e31897 100644 --- a/outlines/fsm/regex.py +++ b/outlines/fsm/regex.py @@ -26,6 +26,7 @@ anything_else, ) from numba.typed.typedobjectutils import _nonoptional +from tqdm import tqdm if TYPE_CHECKING: from outlines.models.tokenizer import Tokenizer @@ -692,6 +693,12 @@ def create_fsm_index_end_to_end( seen: Set[int] = set() next_states = {fsm_info.initial} + pbar = tqdm( + total=len(set(fsm_info.transitions.values())) + + 1, # all transitions plus initial + desc="Compiling FSM index for all state transitions", + ) + while next_states: start_state = next_states.pop() @@ -713,7 +720,11 @@ def create_fsm_index_end_to_end( if end_state not in seen: next_states.add(end_state) - seen.add(start_state) + if start_state not in seen: + pbar.update(1) + seen.add(start_state) + + pbar.close() return states_to_token_subsets diff --git a/outlines/generate/api.py b/outlines/generate/api.py index 3f4f182d2..51a995664 100644 --- a/outlines/generate/api.py +++ b/outlines/generate/api.py @@ -340,15 +340,6 @@ def token_generator() -> Iterator[Union[List[str], str, List[List[str]]]]: return generated_token_ids = sequence.token_ids[:, -num_generated:] generated_sequences = self.tokenizer.decode(generated_token_ids) - next_tokens = [ - token[len(sequence) :] if not stop else "" - for token, sequence, stop in zip( - generated_sequences, - previously_generated_sequences, - is_stop_at_reached, - ) - ] - previously_generated_sequences = generated_sequences if stop_sequences: is_stop_at_reached = [ stop @@ -360,6 +351,25 @@ def token_generator() -> Iterator[Union[List[str], str, List[List[str]]]]: ) ] + generated_sequences = [ + self.format_sequence( + self.strip_stop_sequences(sequence, stop_sequences) + ) + if stop + else sequence + for sequence, stop in zip( + generated_sequences, is_stop_at_reached + ) + ] + next_tokens = [ + token[len(sequence) :] + for token, sequence, stop in zip( + generated_sequences, + previously_generated_sequences, + is_stop_at_reached, + ) + ] + previously_generated_sequences = generated_sequences # We reshape the output to (batch_size, sample_size) output: List[List[str]] = list() for i in range(batch_size): diff --git a/outlines/generate/generator.py b/outlines/generate/generator.py index edda617d0..e506aa035 100644 --- a/outlines/generate/generator.py +++ b/outlines/generate/generator.py @@ -86,8 +86,9 @@ def sequence_generator( token_ids = update_token_ids(token_ids, next_token_ids, ancestors) attention_masks = update_attention_masks(attention_masks, ancestors) kv_cache = reorder_kv_cache(kv_cache, ancestors) - fsms = reorder_fsms(fsms, ancestors) - fsm_states = reorder_fsm_states(fsm_states, ancestors) + if len(ancestors) > 1: + fsms = reorder_fsms(fsms, ancestors) + fsm_states = reorder_fsm_states(fsm_states, ancestors) fsm_states = get_next_fsm_states(fsms, fsm_states, next_token_ids) is_finished = is_generation_finished(fsms, fsm_states) diff --git a/outlines/integrations/llamacpp.py b/outlines/integrations/llamacpp.py index 8c000a6e5..8e18c33e7 100644 --- a/outlines/integrations/llamacpp.py +++ b/outlines/integrations/llamacpp.py @@ -26,7 +26,7 @@ """ import math -from typing import TYPE_CHECKING, Dict, Optional, Set, Type, Union +from typing import TYPE_CHECKING, Optional, Type, Union import numpy as np import torch @@ -36,29 +36,12 @@ from outlines.fsm.guide import CFGGuide, Guide, RegexGuide from outlines.fsm.json_schema import build_regex_from_schema from outlines.integrations.utils import convert_json_schema_to_str +from outlines.models.llamacpp import LlamaCppTokenizer if TYPE_CHECKING: from llama_cpp import Llama -class LlamaCppTokenizer: - def __init__(self, model: "Llama"): - self.eos_token_id = model.token_eos() - self.eos_token = model.tokenizer().decode([self.eos_token_id]) - self.pad_token_id = self.eos_token_id - self.special_tokens: Set[int] = set() - - self.vocabulary: Dict[str, int] = dict() - for t in range(model.n_vocab()): - token_piece = model.tokenizer().decode([t]) - self.vocabulary[token_piece] = t - - self.decode = model.tokenizer().decode - - def convert_token_to_string(self, token: str) -> str: - return token - - class LogitsProcessor: """Bias LlamaCpp generation using a finite state machine. diff --git a/outlines/models/llamacpp.py b/outlines/models/llamacpp.py index 8a6a53a27..840e1364f 100644 --- a/outlines/models/llamacpp.py +++ b/outlines/models/llamacpp.py @@ -1,14 +1,102 @@ import dataclasses -from typing import TYPE_CHECKING, Iterator, List, Optional, TypedDict, Union +import pickle +import warnings +from typing import ( + TYPE_CHECKING, + Dict, + Iterator, + List, + Optional, + Set, + Tuple, + TypedDict, + Union, +) from typing_extensions import Unpack from outlines.generate.api import GenerationParameters, SamplingParameters +from outlines.models.tokenizer import Tokenizer if TYPE_CHECKING: from llama_cpp import Llama, LogitsProcessorList +class LlamaCppTokenizer(Tokenizer): + def __init__(self, model: "Llama"): + self.eos_token_id = model.token_eos() + self.eos_token = model.tokenizer().decode([self.eos_token_id]) + self.pad_token_id = self.eos_token_id + self.special_tokens: Set[int] = set() + + self.vocabulary: Dict[str, int] = dict() + + self.tokenizer = model.tokenizer() + + # TODO: Remove when https://github.com/ggerganov/llama.cpp/pull/5613 is resolved + try: + self.vocabulary = model.tokenizer_.hf_tokenizer.get_vocab() + except AttributeError: + # ### + for t in range(model.n_vocab()): + token_piece = model.tokenizer().decode([t]) + self.vocabulary[token_piece] = t + + # ensure stable ordering of vocabulary + self.vocabulary = { + tok: tok_id + for tok, tok_id in sorted(self.vocabulary.items(), key=lambda x: x[1]) + } + + self._hash = None + + def decode(self, token_ids: List[int]) -> List[str]: + decoded_bytes = self.tokenizer.detokenize(token_ids) + return [decoded_bytes.decode("utf-8", errors="ignore")] + + def encode( + self, prompt: Union[str, List[str]], add_bos: bool = True, special: bool = True + ) -> Tuple[List[int], List[int]]: + if isinstance(prompt, list): + raise NotImplementedError( + "llama-cpp-python tokenizer doesn't support batch tokenization" + ) + token_ids = self.tokenizer.tokenize( + prompt.encode("utf-8", errors="ignore"), add_bos=add_bos, special=special + ) + # generate attention mask, missing from llama-cpp-python + attention_mask = [ + 1 if token_id != self.pad_token_id else 0 for token_id in token_ids + ] + return token_ids, attention_mask + + def convert_token_to_string(self, token: str) -> str: + return token + + def __eq__(self, other): + if not isinstance(other, LlamaCppTokenizer): + return False + return self.__getstate__() == other.__getstate__() + + def __hash__(self): + if self._hash is None: + self._hash = hash(pickle.dumps(self)) + return self._hash + + def __getstate__(self): + """Create a stable representation for outlines.caching""" + return ( + self.vocabulary, + self.eos_token_id, + self.eos_token, + self.pad_token_id, + sorted(self.special_tokens), + ) + + def __setstate__(self, state): + raise NotImplementedError("Cannot load a pickled llamacpp tokenizer") + + class LlamaCppParams(TypedDict, total=False): suffix: Optional[str] temperature: float @@ -288,6 +376,16 @@ def llamacpp( if "verbose" not in llamacpp_model_params: llamacpp_model_params["verbose"] = False + # TODO: Remove when https://github.com/ggerganov/llama.cpp/pull/5613 is resolved + if "tokenizer" not in llamacpp_model_params: + warnings.warn( + "The pre-tokenizer in `llama.cpp` handles unicode improperly " + + "(https://github.com/ggerganov/llama.cpp/pull/5613)\n" + + "Outlines may raise a `RuntimeError` when building the regex index.\n" + + "To circumvent this error when using `models.llamacpp()` you may pass the argument" + + "`tokenizer=llama_cpp.llama_tokenizer.LlamaHFTokenizer.from_pretrained()`\n" + ) + model = Llama.from_pretrained(repo_id, filename, **llamacpp_model_params) return LlamaCpp(model) diff --git a/outlines/models/transformers.py b/outlines/models/transformers.py index 3bc59412e..fae9b8e74 100644 --- a/outlines/models/transformers.py +++ b/outlines/models/transformers.py @@ -1,5 +1,7 @@ from typing import TYPE_CHECKING, List, Optional, Tuple, Union +from datasets.fingerprint import Hasher + from outlines.models.tokenizer import Tokenizer if TYPE_CHECKING: @@ -109,10 +111,15 @@ def __eq__(self, other): return NotImplemented def __hash__(self): - from datasets.fingerprint import Hasher - return hash(Hasher.hash(self.tokenizer)) + def __getstate__(self): + state = {"tokenizer": self.tokenizer} + return state + + def __setstate__(self, state): + self.__init__(state["tokenizer"]) + class Transformers: """Represents a `transformers` model.""" diff --git a/outlines/prompts.py b/outlines/prompts.py index b4e7288bb..01e900c96 100644 --- a/outlines/prompts.py +++ b/outlines/prompts.py @@ -207,6 +207,7 @@ def render(template: str, **values: Optional[Dict[str, Any]]) -> str: env.filters["source"] = get_fn_source env.filters["signature"] = get_fn_signature env.filters["schema"] = get_schema + env.filters["args"] = get_fn_args jinja_template = env.from_string(cleaned_template) @@ -226,6 +227,18 @@ def get_fn_name(fn: Callable): return name +def get_fn_args(fn: Callable): + """Returns the arguments of a function with annotations and default values if provided.""" + if not callable(fn): + raise TypeError("The `args` filter only applies to callables.") + + arg_str_list = [] + signature = inspect.signature(fn) + arg_str_list = [str(param) for param in signature.parameters.values()] + arg_str = ", ".join(arg_str_list) + return arg_str + + def get_fn_description(fn: Callable): """Returns the first line of a callable's docstring.""" if not callable(fn): diff --git a/outlines/types/__init__.py b/outlines/types/__init__.py index 266d3a68e..f4d2b8cd3 100644 --- a/outlines/types/__init__.py +++ b/outlines/types/__init__.py @@ -1,4 +1,4 @@ from . import airports, countries +from .email import Email from .isbn import ISBN -from .phone_numbers import PhoneNumber -from .zip_codes import ZipCode +from .locales import locale diff --git a/outlines/types/email.py b/outlines/types/email.py new file mode 100644 index 000000000..45f8c4b2c --- /dev/null +++ b/outlines/types/email.py @@ -0,0 +1,11 @@ +"""Email Address types.""" +from pydantic import WithJsonSchema +from typing_extensions import Annotated + +# Taken from StackOverflow +# https://stackoverflow.com/a/201378/14773537 +EMAIL_REGEX = r"""(?:[a-z0-9!#$%&'*+/=?^_`{|}~-]+(?:\.[a-z0-9!#$%&'*+/=?^_`{|}~-]+)*|"(?:[\x01-\x08\x0b\x0c\x0e-\x1f\x21\x23-\x5b\x5d-\x7f]|\\[\x01-\x09\x0b\x0c\x0e-\x7f])*")@(?:(?:[a-z0-9](?:[a-z0-9-]*[a-z0-9])?\.)+[a-z0-9](?:[a-z0-9-]*[a-z0-9])?|\[(?:(?:(2(5[0-5]|[0-4][0-9])|1[0-9][0-9]|[1-9]?[0-9]))\.){3}(?:(2(5[0-5]|[0-4][0-9])|1[0-9][0-9]|[1-9]?[0-9])|[a-z0-9-]*[a-z0-9]:(?:[\x01-\x08\x0b\x0c\x0e-\x1f\x21-\x5a\x53-\x7f]|\\[\x01-\x09\x0b\x0c\x0e-\x7f])+)\])""" +Email = Annotated[ + str, + WithJsonSchema({"type": "string", "pattern": EMAIL_REGEX}), +] diff --git a/outlines/types/locales.py b/outlines/types/locales.py new file mode 100644 index 000000000..c5d251bae --- /dev/null +++ b/outlines/types/locales.py @@ -0,0 +1,21 @@ +from dataclasses import dataclass + +from outlines.types.phone_numbers import USPhoneNumber +from outlines.types.zip_codes import USZipCode + + +@dataclass +class US: + ZipCode = USZipCode + PhoneNumber = USPhoneNumber + + +def locale(locale_str: str): + locales = {"us": US} + + if locale_str not in locales: + raise NotImplementedError( + f"The locale {locale_str} is not supported yet. Please don't hesitate to create custom types for you locale and open a Pull Request." + ) + + return locales[locale_str] diff --git a/outlines/types/phone_numbers.py b/outlines/types/phone_numbers.py index 0b27c7890..618687e75 100644 --- a/outlines/types/phone_numbers.py +++ b/outlines/types/phone_numbers.py @@ -10,7 +10,7 @@ US_PHONE_NUMBER = r"(\([0-9]{3}\) |[0-9]{3}-)[0-9]{3}-[0-9]{4}" -PhoneNumber = Annotated[ +USPhoneNumber = Annotated[ str, WithJsonSchema({"type": "string", "pattern": US_PHONE_NUMBER}), ] diff --git a/outlines/types/zip_codes.py b/outlines/types/zip_codes.py index 981efdd03..67d994d5c 100644 --- a/outlines/types/zip_codes.py +++ b/outlines/types/zip_codes.py @@ -10,4 +10,4 @@ US_ZIP_CODE = r"\d{5}(?:-\d{4})?" -ZipCode = Annotated[str, WithJsonSchema({"type": "string", "pattern": US_ZIP_CODE})] +USZipCode = Annotated[str, WithJsonSchema({"type": "string", "pattern": US_ZIP_CODE})] diff --git a/pyproject.toml b/pyproject.toml index 0b6ae6352..c6f72d3e3 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -35,6 +35,10 @@ dependencies = [ "referencing", "jsonschema", "requests", + "tqdm", + "datasets", + "pycountry", + "pyairports", ] dynamic = ["version"] @@ -49,7 +53,6 @@ test = [ "diff-cover", "accelerate", "beartype<0.16.0", - "datasets", "responses", "llama-cpp-python", "huggingface_hub", @@ -57,8 +60,6 @@ test = [ "vllm", "torch", "transformers", - "pycountry", - "pyairports", ] serve = [ "vllm>=0.3.0", @@ -93,7 +94,9 @@ filterwarnings = [ "ignore::numba.core.errors.NumbaPendingDeprecationWarning", "ignore::pydantic.warnings.PydanticDeprecatedSince20", "ignore::FutureWarning:transformers.*", + "ignore::FutureWarning:huggingface_hub.*", "ignore::UserWarning", + "ignore::DeprecationWarning:pyairports.*", ] [tool.mypy] diff --git a/tests/benchmark/test_benchmark_numba_compile.py b/tests/benchmark/test_benchmark_numba_compile.py deleted file mode 100644 index 827d561bd..000000000 --- a/tests/benchmark/test_benchmark_numba_compile.py +++ /dev/null @@ -1,33 +0,0 @@ -import importlib - -import interegular -import numba - -import outlines - -outlines.disable_cache() - - -def test_benchmark_compile_numba(benchmark, tokenizer, mocker): - """Compile a basic regex to benchmark the numba compilation time""" - - def setup(): - from outlines.fsm import regex - - original_njit = numba.njit - - def mock_njit(*args, **kwargs): - kwargs["cache"] = False - return original_njit(*args, **kwargs) - - mocker.patch("numba.njit", new=mock_njit) - importlib.reload(regex) - - regex_pattern, _ = regex.make_deterministic_fsm( - interegular.parse_pattern("a").to_fsm().reduce() - ) - return (regex, regex_pattern, tokenizer), {} - - benchmark.pedantic( - lambda r, *args: r.create_fsm_index_tokenizer(*args), rounds=2, setup=setup - ) diff --git a/tests/fsm/test_fsm.py b/tests/fsm/test_fsm.py index 30047b83d..97d120f20 100644 --- a/tests/fsm/test_fsm.py +++ b/tests/fsm/test_fsm.py @@ -82,9 +82,7 @@ def convert_token_to_string(self, token): assert fsm.is_final_state(state) state = fsm.next_state(state=5, token_id=103) - assert state == 5 - - assert fsm.is_final_state(-1) + assert fsm.is_final_state(state) def test_cfg(): diff --git a/tests/fsm/test_guide.py b/tests/fsm/test_guide.py index 28645f012..68f7e33ee 100644 --- a/tests/fsm/test_guide.py +++ b/tests/fsm/test_guide.py @@ -180,9 +180,7 @@ def convert_token_to_string(self, token): assert fsm.is_final_state(state) state = fsm.get_next_state(state=5, token_id=103) - assert state == 5 - - assert fsm.is_final_state(-1) + assert fsm.is_final_state(state) def test_cfg(): diff --git a/tests/fsm/test_json_schema.py b/tests/fsm/test_json_schema.py index 5b3ad9e39..e691db374 100644 --- a/tests/fsm/test_json_schema.py +++ b/tests/fsm/test_json_schema.py @@ -1,9 +1,10 @@ import json import re -from typing import List +from typing import List, Literal, Union +import interegular import pytest -from pydantic import BaseModel, constr +from pydantic import BaseModel, Field, constr from outlines.fsm.json_schema import ( BOOLEAN, @@ -118,6 +119,8 @@ def test_match_number(pattern, does_match): STRING, [ ("unquotedstring", False), + ('"(parenthesized_string)"', True), + ('"malformed) parenthesis (((() string"', True), ('"quoted_string"', True), (r'"escape_\character"', False), (r'"double_\\escape"', True), @@ -212,8 +215,8 @@ def test_match_number(pattern, does_match): "properties": {"count": {"title": "Count", "type": "integer"}}, "required": ["count"], }, - '\\{[\\n ]*"count"[\\n ]*:[\\n ]*(-)?(0|[1-9][0-9]*)[\\n ]*\\}', - [('{\n "count": 100\n}', True)], + '\\{[ ]?"count"[ ]?:[ ]?(-)?(0|[1-9][0-9]*)[ ]?\\}', + [('{ "count": 100 }', True)], ), # array ( @@ -274,7 +277,7 @@ def test_match_number(pattern, does_match): rf"""\{{{WHITESPACE}"test_dict"{WHITESPACE}:{WHITESPACE}\{{{WHITESPACE}({STRING}{WHITESPACE}:{WHITESPACE}{STRING}({WHITESPACE},{WHITESPACE}{STRING}{WHITESPACE}:{WHITESPACE}{STRING}){{0,}})?{WHITESPACE}\}}{WHITESPACE}\}}""", [ ("""{ "test_dict":{"foo":"bar","baz": "bif"}}""", True), - ("""{ "test_dict":{"foo":"bar"\n}}""", True), + ("""{ "test_dict":{"foo":"bar" }}""", True), ("""{ "test_dict":{}}""", True), ("""{ "WRONG_KEY":{}}""", False), ("""{ "test_dict":{"wrong_type" 1}}""", False), @@ -321,7 +324,7 @@ def test_match_number(pattern, does_match): "title": "Foo", "oneOf": [{"type": "string"}, {"type": "number"}, {"type": "boolean"}], }, - rf"(({STRING})(?!.*({NUMBER}|{BOOLEAN}))|({NUMBER})(?!.*({STRING}|{BOOLEAN}))|({BOOLEAN})(?!.*({STRING}|{NUMBER})))", + rf'((?:"{STRING_INNER}*")|(?:{NUMBER})|(?:{BOOLEAN}))', [ ("12.3", True), ("true", True), @@ -351,6 +354,15 @@ def test_match_number(pattern, does_match): rf"({STRING}{INTEGER})", [('"a"1', True), ('"a"', False), ('"1"', False)], ), + # Tuple / prefixItems + ( + { + "title": "Foo", + "prefixItems": [{"type": "string"}, {"type": "integer"}], + }, + rf"\[{WHITESPACE}{STRING}{WHITESPACE},{WHITESPACE}{INTEGER}{WHITESPACE}\]", + [('["a", 1]', True), ('["a", 1, 1]', False), ("[]", False)], + ), # Nested schema ( { @@ -366,8 +378,8 @@ def test_match_number(pattern, does_match): }, "required": ["fuzz"], }, - f'\\{{[\\n ]*"fuzz"[\\n ]*:[\\n ]*\\{{[\\n ]*"spam"[\\n ]*:[\\n ]*{INTEGER}[\\n ]*\\}}[\\n ]*\\}}', - [('{\n "fuzz": {\n "spam": 100\n }\n}', True)], + f'\\{{[ ]?"fuzz"[ ]?:[ ]?\\{{[ ]?"spam"[ ]?:[ ]?{INTEGER}[ ]?\\}}[ ]?\\}}', + [('{ "fuzz": { "spam": 100 }}', True)], ), # Schema with a reference ( @@ -381,7 +393,7 @@ def test_match_number(pattern, does_match): }, "required": ["user_id", "name", "a"], }, - f'\\{{[\\n ]*"user_id"[\\n ]*:[\\n ]*{INTEGER}[\\n ]*,[\\n ]*"name"[\\n ]*:[\\n ]*{STRING}[\\n ]*,[\\n ]*"a"[\\n ]*:[\\n ]*{STRING}[\\n ]*\\}}', + f'\\{{[ ]?"user_id"[ ]?:[ ]?{INTEGER}[ ]?,[ ]?"name"[ ]?:[ ]?{STRING}[ ]?,[ ]?"a"[ ]?:[ ]?{STRING}[ ]?\\}}', [('{"user_id": 100, "name": "John", "a": "Marc"}', True)], ), ( @@ -396,7 +408,7 @@ def test_match_number(pattern, does_match): }, "required": ["user_id", "name", "name2"], }, - f'\\{{[\\n ]*"user_id"[\\n ]*:[\\n ]*{INTEGER}[\\n ]*,[\\n ]*"name"[\\n ]*:[\\n ]*{STRING}[\\n ]*,[\\n ]*"name2"[\\n ]*:[\\n ]*{STRING}[\\n ]*\\}}', + f'\\{{[ ]?"user_id"[ ]?:[ ]?{INTEGER}[ ]?,[ ]?"name"[ ]?:[ ]?{STRING}[ ]?,[ ]?"name2"[ ]?:[ ]?{STRING}[ ]?\\}}', [('{"user_id": 100, "name": "John", "name2": "Marc"}', True)], ), ( @@ -438,7 +450,7 @@ def test_match_number(pattern, does_match): } }, }, - f'\\{{[\\n ]*"name"[\\n ]*:[\\n ]*{STRING}[\\n ]*,[\\n ]*"last_name"[\\n ]*:[\\n ]*{STRING}[\\n ]*,[\\n ]*"address"[\\n ]*:[\\n ]*\\{{[\\n ]*"city"[\\n ]*:[\\n ]*{STRING}[\\n ]*\\}}[\\n ]*\\}}', + f'\\{{[ ]?"name"[ ]?:[ ]?{STRING}[ ]?,[ ]?"last_name"[ ]?:[ ]?{STRING}[ ]?,[ ]?"address"[ ]?:[ ]?\\{{[ ]?"city"[ ]?:[ ]?{STRING}[ ]?\\}}[ ]?\\}}', [ ( '{"name": "John", "last_name": "Doe", "address": {"city": "Paris"}}', @@ -459,7 +471,7 @@ def test_match_number(pattern, does_match): "title": "Character", "type": "object", }, - f'\\{{[\\n ]*"name"[\\n ]*:[\\n ]*{STRING}([\\n ]*,[\\n ]*"age"[\\n ]*:[\\n ]*({INTEGER}|null))?([\\n ]*,[\\n ]*"weapon"[\\n ]*:[\\n ]*({STRING}|null))?[\\n ]*\\}}', + f'\\{{[ ]?"name"[ ]?:[ ]?{STRING}([ ]?,[ ]?"age"[ ]?:[ ]?({INTEGER}|null))?([ ]?,[ ]?"weapon"[ ]?:[ ]?({STRING}|null))?[ ]?\\}}', [ ('{ "name" : "Player" }', True), ('{ "name" : "Player", "weapon" : "sword" }', True), @@ -479,7 +491,7 @@ def test_match_number(pattern, does_match): "title": "Character", "type": "object", }, - f'\\{{[\\n ]*"name"[\\n ]*:[\\n ]*{STRING}[\\n ]*,([\\n ]*"age"[\\n ]*:[\\n ]*({INTEGER}|null)[\\n ]*,)?[\\n ]*"weapon"[\\n ]*:[\\n ]*{STRING}([\\n ]*,[\\n ]*"strength"[\\n ]*:[\\n ]*({INTEGER}|null))?[\\n ]*\\}}', + f'\\{{[ ]?"name"[ ]?:[ ]?{STRING}[ ]?,([ ]?"age"[ ]?:[ ]?({INTEGER}|null)[ ]?,)?[ ]?"weapon"[ ]?:[ ]?{STRING}([ ]?,[ ]?"strength"[ ]?:[ ]?({INTEGER}|null))?[ ]?\\}}', [ ('{ "name" : "Player" , "weapon" : "sword" }', True), ( @@ -503,7 +515,7 @@ def test_match_number(pattern, does_match): "title": "Character", "type": "object", }, - f'\\{{([\\n ]*"name"[\\n ]*:[\\n ]*({STRING}|null)[\\n ]*,)?[\\n ]*"age"[\\n ]*:[\\n ]*{INTEGER}[\\n ]*,[\\n ]*"armor"[\\n ]*:[\\n ]*{STRING}[\\n ]*,([\\n ]*"strength"[\\n ]*:[\\n ]*({INTEGER}|null)[\\n ]*,)?[\\n ]*"weapon"[\\n ]*:[\\n ]*{STRING}[\\n ]*\\}}', + f'\\{{([ ]?"name"[ ]?:[ ]?({STRING}|null)[ ]?,)?[ ]?"age"[ ]?:[ ]?{INTEGER}[ ]?,[ ]?"armor"[ ]?:[ ]?{STRING}[ ]?,([ ]?"strength"[ ]?:[ ]?({INTEGER}|null)[ ]?,)?[ ]?"weapon"[ ]?:[ ]?{STRING}[ ]?\\}}', [ ( '{ "name" : "Player", "age" : 10, "armor" : "plate", "strength" : 11, "weapon" : "sword" }', @@ -527,7 +539,7 @@ def test_match_number(pattern, does_match): "title": "Character", "type": "object", }, - f'\\{{([\\n ]*"name"[\\n ]*:[\\n ]*({STRING}|null)([\\n ]*,[\\n ]*"age"[\\n ]*:[\\n ]*({INTEGER}|null))?([\\n ]*,[\\n ]*"strength"[\\n ]*:[\\n ]*({INTEGER}|null))?|([\\n ]*"name"[\\n ]*:[\\n ]*({STRING}|null)[\\n ]*,)?[\\n ]*"age"[\\n ]*:[\\n ]*({INTEGER}|null)([\\n ]*,[\\n ]*"strength"[\\n ]*:[\\n ]*({INTEGER}|null))?|([\\n ]*"name"[\\n ]*:[\\n ]*({STRING}|null)[\\n ]*,)?([\\n ]*"age"[\\n ]*:[\\n ]*({INTEGER}|null)[\\n ]*,)?[\\n ]*"strength"[\\n ]*:[\\n ]*({INTEGER}|null))?[\\n ]*\\}}', + f'\\{{([ ]?"name"[ ]?:[ ]?({STRING}|null)([ ]?,[ ]?"age"[ ]?:[ ]?({INTEGER}|null))?([ ]?,[ ]?"strength"[ ]?:[ ]?({INTEGER}|null))?|([ ]?"name"[ ]?:[ ]?({STRING}|null)[ ]?,)?[ ]?"age"[ ]?:[ ]?({INTEGER}|null)([ ]?,[ ]?"strength"[ ]?:[ ]?({INTEGER}|null))?|([ ]?"name"[ ]?:[ ]?({STRING}|null)[ ]?,)?([ ]?"age"[ ]?:[ ]?({INTEGER}|null)[ ]?,)?[ ]?"strength"[ ]?:[ ]?({INTEGER}|null))?[ ]?\\}}', [ ('{ "name" : "Player" }', True), ('{ "name" : "Player", "age" : 10, "strength" : 10 }', True), @@ -707,6 +719,60 @@ def test_format(schema, regex, examples): ('{"time":20:20:39Z}', False), # missing quotes for value ], ), + # Unconstrained Object + ( + { + "title": "Foo", + "type": "object", + }, + [ + ("{}", True), + ('{"a": 1, "b": null}', True), + ('{"a": {"z": {"g": 4}}, "b": null}', True), + ("1234", False), # not an object + ('["a", "a"]', False), # not an array + ], + ), + # Unconstrained Array + ( + { + "type": "array", + }, + [ + ("[1, {}, false]", True), + ("[{}]", True), + ('[{"a": {"z": "q"}, "b": null}]', True), + ('[{"a": [1, 2, true], "b": null}]', True), + ('[{"a": [1, 2, true], "b": {"a": "b"}}, 1, true, [1, [2]]]', True), + # too deep, default unconstrained depth limit = 2 + ( + '[{"a": [1, 2, true], "b": {"a": "b"}}, 1, true, [1, [2, [3]]]]', + False, + ), + ('[{"a": {"z": {"g": 4}}, "b": null}]', False), + ("[[[[1]]]]", False), + # not an array + ("{}", False), + ('{"a": 1, "b": null}', False), + ('{"a": {"z": {"g": 4}}, "b": null}', False), + ("1234", False), # not an array + ('{"a": "a"}', False), # not an array + ], + ), + # No schema / unconstrained value + ( + {}, + [ + ('"aaabbuecuh"', True), # string + ("5.554", True), # number + ("true", True), # boolean + ("null", True), # null + ("5999", True), # integer + ('["a", "b"]', True), # array + ('{"key": {"k2": "value"}}', True), # nested object + ("this isnt valid json", False), + ], + ), ], ) def test_format_without_regex(schema, examples): @@ -721,7 +787,7 @@ def test_format_without_regex(schema, examples): assert match is None -@pytest.mark.parametrize("whitespace_pattern", [None, r"[\n ]?", "abc"]) +@pytest.mark.parametrize("whitespace_pattern", [None, r"[\n ]*", "abc"]) def test_json_schema_custom_whitespace_pattern(whitespace_pattern): """assert whitespace_pattern setting respected""" @@ -743,10 +809,32 @@ class MockModel(BaseModel): ) mock_result_maybe_ws = """{"foo" : 4 ,"bar":"baz baz baz bar"}""" - match_default_ws = re.fullmatch(pattern, mock_result_mult_ws) + match_default_ws = re.fullmatch(pattern, mock_result_maybe_ws) if whitespace_pattern is None: assert match_default_ws else: - assert match_default_ws is None + assert re.fullmatch(pattern, mock_result_mult_ws) + + +def test_one_of_doesnt_produce_illegal_lookaround(): + """Reproduces failure in https://github.com/outlines-dev/outlines/issues/823""" + + class Cat(BaseModel): + pet_type: Literal["cat"] + meows: int + + class Dog(BaseModel): + pet_type: Literal["dog"] + barks: float + + class Model(BaseModel): + pet: Union[Cat, Dog] = Field(..., discriminator="pet_type") + n: int + + json_schema = Model.schema_json() + + json_schema = Model.schema_json() + pattern = build_regex_from_schema(json_schema, whitespace_pattern=None) - assert re.fullmatch(pattern, mock_result_maybe_ws) + # check if the pattern uses lookarounds incompatible with interegular.Pattern.to_fsm() + interegular.parse_pattern(pattern).to_fsm() diff --git a/tests/generate/conftest.py b/tests/generate/conftest.py new file mode 100644 index 000000000..ef7e40eed --- /dev/null +++ b/tests/generate/conftest.py @@ -0,0 +1,24 @@ +from importlib import reload + +import pytest + + +@pytest.fixture +def temp_cache_dir(): + import os + import tempfile + + import outlines.caching + import outlines.fsm.guide + + with tempfile.TemporaryDirectory() as tempdir: + os.environ["OUTLINES_CACHE_DIR"] = tempdir + outlines.caching.get_cache.cache_clear() + reload(outlines) + reload(outlines.fsm.guide) + cache_status = outlines.caching._caching_enabled + try: + outlines.caching._caching_enabled = True + yield + finally: + outlines.caching._caching_enabled = cache_status diff --git a/tests/generate/test_integration_llamacpp.py b/tests/generate/test_integration_llamacpp.py index d036b560f..b7eb8b3cb 100644 --- a/tests/generate/test_integration_llamacpp.py +++ b/tests/generate/test_integration_llamacpp.py @@ -247,3 +247,90 @@ def test_llamacpp_cfg(model): prompt = "<|im_start|>user\nOutput a short and valid JSON object with two keys.<|im_end|>\n><|im_start|>assistant\n" result = generate.cfg(model, grammars.arithmetic)(prompt, seed=11) assert isinstance(result, str) + + +@pytest.mark.parametrize( + "repo,model_path,hf_tokenizer_uri", + [ + ("Qwen/Qwen1.5-0.5B-Chat-GGUF", "*q2*.gguf", "Qwen/Qwen1.5-0.5B-Chat"), + ("TheBloke/phi-2-GGUF", "*Q2*.gguf", "microsoft/phi-2"), + ], +) +def test_byte_tokenizer_regression(repo, model_path, hf_tokenizer_uri): + """Reproduce https://github.com/outlines-dev/outlines/issues/820""" + import llama_cpp + + model = models.llamacpp( + repo, + model_path, + tokenizer=llama_cpp.llama_tokenizer.LlamaHFTokenizer.from_pretrained( + hf_tokenizer_uri + ), + ) + generator = generate.choice(model, ["skirt", "dress", "pen", "jacket"]) + generator("Pick the odd word out: skirt, dress, pen, jacket") + + +def test_llama_cpp_pre_tokenizer_remains_broken(): + """If fails, llama.cpp pre-tokenizer is fixed -> revert #892, remove `with pytest.raises`""" + repo = "Qwen/Qwen1.5-0.5B-Chat-GGUF" + model_path = "*q2*.gguf" + + model = models.llamacpp(repo, model_path) + with pytest.raises(RuntimeError): + generate.choice(model, ["skirt", "dress", "pen", "jacket"]) + + +def test_RegexGuide_caching(model, temp_cache_dir): + import llama_cpp + + import outlines.caching + from outlines.fsm.guide import create_states_mapping + + assert outlines.caching._caching_enabled + + regex = r"((25[0-5]|2[0-4]\d|[01]?\d\d?)\.){3}(25[0-5]|2[0-4]\d|[01]?\d\d?)" + prompt = "What is the IP address of the Google DNS servers? " + + cache = outlines.caching.get_cache() + + # Returns (hits, misses) + _ = cache.stats(enable=True) + assert cache.statistics + + assert create_states_mapping.__memory__ is cache + + generator = generate.regex(model, regex, sampler=samplers.greedy()) + assert cache.stats() == (0, 1) + + model_2 = models.llamacpp( + "Qwen/Qwen1.5-0.5B-Chat-GGUF", + "*q2*.gguf", + tokenizer=llama_cpp.llama_tokenizer.LlamaHFTokenizer.from_pretrained( + "Qwen/Qwen1.5-0.5B-Chat" + ), + ) + generator_2 = generate.regex(model_2, regex, sampler=samplers.greedy()) + assert cache.stats() == (0, 2) + + # These two different models and tokenizers should not have the same state + # mapping results + assert ( + generator.logits_processor.fsm.states_to_token_maps + != generator_2.logits_processor.fsm.states_to_token_maps + ) + + generator_3 = generate.regex(model_2, regex, sampler=samplers.greedy()) + assert cache.stats() == (1, 2) + assert ( + generator_2.logits_processor.fsm.states_to_token_maps + == generator_3.logits_processor.fsm.states_to_token_maps + ) + + # Just for fun... + structured = generator(prompt, max_tokens=30) + structured_2 = generator_2(prompt, max_tokens=30) + + assert re.fullmatch(regex, structured) + assert re.fullmatch(regex, structured_2) + assert structured != structured_2 diff --git a/tests/generate/test_integration_transformers.py b/tests/generate/test_integration_transformers.py index 38525a076..da08bed71 100644 --- a/tests/generate/test_integration_transformers.py +++ b/tests/generate/test_integration_transformers.py @@ -11,7 +11,7 @@ import outlines.models as models from outlines.fsm.regex import reduced_vocabulary from outlines.models.transformers import Transformers, TransformerTokenizer -from outlines.samplers import beam_search, multinomial +from outlines.samplers import beam_search, greedy, multinomial def test_transformers_integration_text(): @@ -632,3 +632,47 @@ def test_transformers_use_existing_model_and_tokenizer(): model = Transformers(hf_model, hf_tokenizer) sequence = generate.text(model)("Write a short sentence ", rng=rng) assert isinstance(sequence, str) + + +def test_RegexGuide_caching(temp_cache_dir): + import outlines.caching + from outlines.fsm.guide import create_states_mapping + + assert outlines.caching._caching_enabled + + regex = r"((25[0-5]|2[0-4]\d|[01]?\d\d?)\.){3}(25[0-5]|2[0-4]\d|[01]?\d\d?)" + prompt = "What is the IP address of the Google DNS servers? " + + cache = outlines.caching.get_cache() + + # Returns (hits, misses) + _ = cache.stats(enable=True) + assert cache.statistics + + assert create_states_mapping.__memory__ is cache + + model = models.transformers( + "hf-internal-testing/tiny-random-XLMRobertaXLForCausalLM" + ) + generator = generate.regex(model, regex, sampler=greedy()) + assert cache.stats() == (0, 1) + + model_2 = models.transformers("hf-internal-testing/tiny-random-GPTJForCausalLM") + generator_2 = generate.regex(model_2, regex, sampler=greedy()) + assert cache.stats() == (0, 2) + + # These two different models and tokenizers should not have the same state + # mapping results + assert generator.fsm.states_to_token_maps != generator_2.fsm.states_to_token_maps + + generator_3 = generate.regex(model_2, regex, sampler=greedy()) + assert cache.stats() == (1, 2) + assert generator_2.fsm.states_to_token_maps == generator_3.fsm.states_to_token_maps + + # Just for fun... + structured = generator(prompt, max_tokens=30) + structured_2 = generator_2(prompt, max_tokens=30) + + assert re.fullmatch(regex, structured) + assert re.fullmatch(regex, structured_2) + assert structured != structured_2 diff --git a/tests/models/test_transformers.py b/tests/models/test_transformers.py index b4e410096..f4596a2df 100644 --- a/tests/models/test_transformers.py +++ b/tests/models/test_transformers.py @@ -107,6 +107,14 @@ def test_tokenizer_eq_hash(): tokenizer_hf = AutoTokenizer.from_pretrained("gpt2") tokenizer = TransformerTokenizer(tokenizer_hf) - tokenizer2 = TransformerTokenizer(tokenizer_hf) - assert tokenizer == tokenizer2 - assert hash(tokenizer) == hash(tokenizer2) + tokenizer_2 = TransformerTokenizer(tokenizer_hf) + + assert tokenizer == tokenizer_2 + assert hash(tokenizer) == hash(tokenizer_2) + + tokenizer_hf_2 = AutoTokenizer.from_pretrained("gpt2") + tokenizer_hf_2.add_tokens(["test_token"]) + + tokenizer_3 = TransformerTokenizer(tokenizer_hf_2) + assert tokenizer != tokenizer_3 + assert hash(tokenizer) != hash(tokenizer_3) diff --git a/tests/test_cache.py b/tests/test_cache.py index 5a2de778e..eb4ec406e 100644 --- a/tests/test_cache.py +++ b/tests/test_cache.py @@ -1,5 +1,6 @@ import os import tempfile +import unittest import diskcache import pytest @@ -157,3 +158,33 @@ def foo(): # assert with version upgrade, old cache is invalidated and new cache is used a, b = foo() + + +def test_cache_disabled_decorator(test_cache): + """Ensure cache can be disabled in a local scope""" + + from outlines.caching import cache_disabled + + mock = unittest.mock.MagicMock() + + @test_cache + def fn(): + mock() + return 1 + + # first call isn't cached + fn() + assert mock.call_count == 1 + + # second call doesn't run fn, uses cache + fn() + assert mock.call_count == 1 + + # cache_disabled decorator disables cache within scope + with cache_disabled(): + fn() + assert mock.call_count == 2 # called once in cache_disabled scope + + # scope has exited, cache is enabled again + fn() + assert mock.call_count == 2 diff --git a/tests/test_prompts.py b/tests/test_prompts.py index 65eeb2022..a0433c0e5 100644 --- a/tests/test_prompts.py +++ b/tests/test_prompts.py @@ -1,4 +1,4 @@ -from typing import List +from typing import Dict, List import pytest from pydantic import BaseModel, Field @@ -252,3 +252,63 @@ def source_ppt(model): prompt = source_ppt(response) assert prompt == '{\n "one": "a description",\n "two": ""\n}' + + +def test_prompt_args(): + def no_args(): + pass + + def with_args(x, y, z): + pass + + def with_annotations(x: bool, y: str, z: Dict[int, List[str]]): + pass + + def with_defaults(x=True, y="Hi", z={4: ["I", "love", "outlines"]}): + pass + + def with_annotations_and_defaults( + x: bool = True, + y: str = "Hi", + z: Dict[int, List[str]] = {4: ["I", "love", "outlines"]}, + ): + pass + + def with_all( + x1, + y1, + z1, + x2: bool, + y2: str, + z2: Dict[int, List[str]], + x3=True, + y3="Hi", + z3={4: ["I", "love", "outlines"]}, + x4: bool = True, + y4: str = "Hi", + z4: Dict[int, List[str]] = {4: ["I", "love", "outlines"]}, + ): + pass + + @outlines.prompt + def args_prompt(fn): + """args: {{ fn | args }}""" + + assert args_prompt(no_args) == "args: " + assert args_prompt(with_args) == "args: x, y, z" + assert ( + args_prompt(with_annotations) + == "args: x: bool, y: str, z: Dict[int, List[str]]" + ) + assert ( + args_prompt(with_defaults) + == "args: x=True, y='Hi', z={4: ['I', 'love', 'outlines']}" + ) + assert ( + args_prompt(with_annotations_and_defaults) + == "args: x: bool = True, y: str = 'Hi', z: Dict[int, List[str]] = {4: ['I', 'love', 'outlines']}" + ) + assert ( + args_prompt(with_all) + == "args: x1, y1, z1, x2: bool, y2: str, z2: Dict[int, List[str]], x3=True, y3='Hi', z3={4: ['I', 'love', 'outlines']}, x4: bool = True, y4: str = 'Hi', z4: Dict[int, List[str]] = {4: ['I', 'love', 'outlines']}" + ) diff --git a/tests/test_types.py b/tests/test_types.py index 2391ccc18..5e60348b2 100644 --- a/tests/test_types.py +++ b/tests/test_types.py @@ -10,12 +10,12 @@ @pytest.mark.parametrize( "custom_type,test_string,should_match", [ - (types.PhoneNumber, "12", False), - (types.PhoneNumber, "(123) 123-1234", True), - (types.PhoneNumber, "123-123-1234", True), - (types.ZipCode, "12", False), - (types.ZipCode, "12345", True), - (types.ZipCode, "12345-1234", True), + (types.phone_numbers.USPhoneNumber, "12", False), + (types.phone_numbers.USPhoneNumber, "(123) 123-1234", True), + (types.phone_numbers.USPhoneNumber, "123-123-1234", True), + (types.zip_codes.USZipCode, "12", False), + (types.zip_codes.USZipCode, "12345", True), + (types.zip_codes.USZipCode, "12345-1234", True), (types.ISBN, "ISBN 0-1-2-3-4-5", False), (types.ISBN, "ISBN 978-0-596-52068-7", True), # (types.ISBN, "ISBN 978-0-596-52068-1", True), wrong check digit @@ -24,6 +24,12 @@ (types.ISBN, "9780596520687", True), (types.ISBN, "ISBN-10: 0-596-52068-9", True), (types.ISBN, "0-596-52068-9", True), + (types.Email, "eitan@gmail.com", True), + (types.Email, "99@yahoo.com", True), + (types.Email, "eitan@.gmail.com", False), + (types.Email, "myemail", False), + (types.Email, "eitan@gmail", False), + (types.Email, "eitan@my.custom.domain", True), ], ) def test_type_regex(custom_type, test_string, should_match): @@ -42,6 +48,27 @@ class Model(BaseModel): assert does_match is should_match +def test_locale_not_implemented(): + with pytest.raises(NotImplementedError): + types.locale("fr") + + +@pytest.mark.parametrize( + "locale_str,base_types,locale_types", + [ + ( + "us", + ["ZipCode", "PhoneNumber"], + [types.zip_codes.USZipCode, types.phone_numbers.USPhoneNumber], + ) + ], +) +def test_locale(locale_str, base_types, locale_types): + for base_type, locale_type in zip(base_types, locale_types): + type = getattr(types.locale(locale_str), base_type) + assert type == locale_type + + @pytest.mark.parametrize( "custom_type,test_string,should_match", [