
Use a trie for scanning during index construction #507

Closed · lapp0 wants to merge 2 commits from the speed-up-regexfsm branch

Conversation

@lapp0 (Contributor) commented Jan 6, 2024

Update state_scan_tokens to use a trie. We are performing an expensive _walk_fsm for every token × every state, which is the current bottleneck. This change reduces the number of calls to _walk_fsm.
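
Below is a minimal sketch of the idea (illustrative only, not this PR's actual code): group the vocabulary in a character trie so tokens sharing a prefix also share the FSM transitions computed for that prefix. The TrieNode and build_token_trie names are hypothetical.

class TrieNode:
    def __init__(self):
        self.children = {}   # char -> TrieNode
        self.token_ids = []  # ids of tokens that end exactly at this node


def build_token_trie(vocabulary):
    """vocabulary: iterable of (token_string, token_id) pairs."""
    root = TrieNode()
    for token, token_id in vocabulary:
        node = root
        for char in token:
            node = node.children.setdefault(char, TrieNode())
        node.token_ids.append(token_id)
    return root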

Performance

The author of interegular pushed my FSM.reduce() optimization to PyPI in version 0.3.3.

(For simple regexes, interegular's reduce() doesn't take much time, so the difference between 0.3.3 and 0.3.2 is mostly run-to-run variation in the rest of RegexFSM.)

  • main (0.3.2): Performance of main on old interegular
  • main (0.3.3): Performance with interegular change
  • PR (wo/ trie): Performance with interegular change + PR's numbafication
  • PR: Performance with interegular change + PR's numbafication + PR's token trie

Benchmarks indicate all RegexFSM constructions are faster than before: the worst sample is 20% faster and the best is 700% faster.

Benchmarks

email:

[a-z0-9!#$%&'*+/=?^_`{|}~-]+(?:\.[a-z0-9!#$%&'*+/=?^_`{|}~-]+)*@(?:[a-z0-9](?:[a-z0-9-]*[a-z0-9])?\.)+[a-z0-9](?:[a-z0-9-]*[a-z0-9])?
  • main (0.3.2): 0.5830576419830322
  • main (0.3.3): 0.58341383934021
  • PR (wo/ trie): 0.34377002716064453
  • PR: 0.47698211669921875 (bad)

complex_phone:

\+?\d{1,4}?[-.\s]?\(?\d{1,3}?\)?[-.\s]?\d{1,4}[-.\s]?\d{1,4}[-.\s]?\d{1,9}
  • main (0.3.2): 1.5468947887420654
  • main (0.3.3): 1.4503862857818604
  • PR (wo/ trie): 1.1407146453857422
  • PR: 0.32466745376586914 (good)

simple_phone:

\+?[1-9][0-9]{7,14}
  • main (0.3.2): 0.4189467430114746
  • main (0.3.3): 0.4271810054779053
  • PR (wo/ trie): 0.3233828544616699
  • PR: 0.14258384704589844 (good)

date:

([1-9]|0[1-9]|1[0-9]|2[0-9]|3[0-1])(\.|-|/)([1-9]|0[1-9]|1[0-2])(\.|-|/)([0-9][0-9]|19[0-9][0-9]|20[0-9][0-9])|([0-9][0-9]|19[0-9][0-9]|20[0-9][0-9])(\.|-|/)([1-9]|0[1-9]|1[0-2])(\.|-|/)([1-9]|0[1-9]|1[0-9]|2[0-9]|3[0-1])
  • main (0.3.2): 0.8260455131530762
  • main (0.3.3): 0.8374199867248535
  • PR (wo/ trie): 0.6275467872619629
  • PR: 0.12444806098937988 (good)

time:

(0?[1-9]|1[0-2]):[0-5]\d\s?(am|pm)?
  • main (0.3.2): 0.26282405853271484
  • main (0.3.3): 0.28090667724609375
  • PR (wo/ trie): 0.19170904159545898
  • PR: 0.10238885879516602 (good)

ip:

(?:(?:25[0-5]|2[0-4][0-9]|[01]?[0-9][0-9]?)\.){3}(?:25[0-5]|2[0-4][0-9]|[01]?[0-9][0-9]?)
  • main (0.3.2): 0.5932526588439941
  • main (0.3.3): 0.5779187679290771
  • PR (wo/ trie): 0.4326210021972656
  • PR: 0.11676287651062012 (good)

url:

(https?:\/\/)?([\da-z\.-]+)\.([a-z\.]{2,6})([\/\w \.-]*)*\/?
  • main (0.3.2): 0.9047980308532715
  • main (0.3.3): 0.7660520076751709
  • PR (wo/ trie): 0.5781815052032471
  • PR: 0.7124283313751221 (bad)

ssn:

\d{3}-\d{2}-\d{4}
  • main (0.3.2): 0.306537389755249
  • main (0.3.3): 0.34891176223754883
  • PR (wo/ trie): 0.22051334381103516
  • PR: 0.10631608963012695 (good)

very_complex:

  • main (0.3.2): 24.344436407089233
  • main (0.3.3): 11.929908275604248
  • PR (wo/ trie): 8.78506064414978
  • PR: 3.1125617027282715 (good)

json_schema:

\{[\n ]*"name"[\n ]*:[\n ]*"(?:[^"\\\x00-\x1f\x7f-\x9f]|\\.){,10}"[\n ]*,[\n ]*"age"[\n ]*:[\n ]*(0|[1-9][0-9]*)[\n ]*,[\n ]*"armor"[\n ]*:[\n ]*("leather"|"chainmail"|"plate")[\n ]*,[\n ]*"strength"[\n ]*:[\n ]*(0|[1-9][0-9]*)[\n ]*\}
  • main (0.3.2): 4.285711050033569
  • main (0.3.3): 4.1461546421051025
  • PR (wo/ trie): 3.075744867324829
  • PR: 2.2587642669677734 (good)

complex_json_schema:

\{[\n ]*"id"[\n ]*:[\n ]*(-)?((0|[1-9][0-9]*))(\.[0-9]+)?([eE][+-][0-9]+)?[\n ]*,[\n ]*"work"[\n ]*:[\n ]*\{[\n ]*"id"[\n ]*:[\n ]*(-)?((0|[1-9][0-9]*))(\.[0-9]+)?([eE][+-][0-9]+)?[\n ]*,[\n ]*"name"[\n ]*:[\n ]*"(?:[^"\\\x00-\x1f\x7f-\x9f]|\\.)*"[\n ]*,[\n ]*"composer"[\n ]*:[\n ]*\{[\n ]*"id"[\n ]*:[\n ]*(-)?((0|[1-9][0-9]*))(\.[0-9]+)?([eE][+-][0-9]+)?[\n ]*,[\n ]*"name"[\n ]*:[\n ]*"(?:[^"\\\x00-\x1f\x7f-\x9f]|\\.)*"[\n ]*,[\n ]*"functions"[\n ]*:[\n ]*\[("(?:[^"\\\x00-\x1f\x7f-\x9f]|\\.)*")(,("(?:[^"\\\x00-\x1f\x7f-\x9f]|\\.)*"))*\][\n ]*\}[\n ]*\}[\n ]*,[\n ]*"recording_artists"[\n ]*:[\n ]*\[(\{[\n ]*"id"[\n ]*:[\n ]*(-)?((0|[1-9][0-9]*))(\.[0-9]+)?([eE][+-][0-9]+)?[\n ]*,[\n ]*"name"[\n ]*:[\n ]*"(?:[^"\\\x00-\x1f\x7f-\x9f]|\\.)*"[\n ]*,[\n ]*"functions"[\n ]*:[\n ]*\[("(?:[^"\\\x00-\x1f\x7f-\x9f]|\\.)*")(,("(?:[^"\\\x00-\x1f\x7f-\x9f]|\\.)*"))*\][\n ]*\})(,(\{[\n ]*"id"[\n ]*:[\n ]*(-)?((0|[1-9][0-9]*))(\.[0-9]+)?([eE][+-][0-9]+)?[\n ]*,[\n ]*"name"[\n ]*:[\n ]*"(?:[^"\\\x00-\x1f\x7f-\x9f]|\\.)*"[\n ]*,[\n ]*"functions"[\n ]*:[\n ]*\[("(?:[^"\\\x00-\x1f\x7f-\x9f]|\\.)*")(,("(?:[^"\\\x00-\x1f\x7f-\x9f]|\\.)*"))*\][\n ]*\}))*\][\n ]*\}
  • main (0.3.2): 9.537730693817139
  • main (0.3.3): 8.472238540649414
  • PR (wo/ trie): 5.008383274078369
  • PR: 2.342341661453247 (good)

Benchmark code:

import cProfile
import pstats

import time
from outlines.models.transformers import TransformerTokenizer
from outlines.fsm.json_schema import build_regex_from_object

schema = """{
        "$defs": {
            "Armor": {
                "enum": ["leather", "chainmail", "plate"],
                "title": "Armor",
                "type": "string"
            }
        },
        "properties": {
            "name": {"maxLength": 10, "title": "Name", "type": "string"},
            "age": {"title": "Age", "type": "integer"},
            "armor": {"$ref": "#/$defs/Armor"},
            "strength": {"title": "Strength", "type": "integer"}\
        },
        "required": ["name", "age", "armor", "strength"],
        "title": "Character",
        "type": "object"
    }"""

complex_schema = """{
  "$schema": "http://json-schema.org/draft-04/schema#",
  "title": "Schema for a recording",
  "type": "object",
  "definitions": {
    "artist": {
      "type": "object",
      "properties": {
        "id": {"type": "number"},
        "name": {"type": "string"},
        "functions": {
          "type": "array",
          "items": {"type": "string"}
        }
      },
      "required": ["id", "name", "functions"]
    }
  },
  "properties": {
    "id": {"type": "number"},
    "work": {
      "type": "object",
      "properties": {
        "id": {"type": "number"},
        "name": {"type": "string"},
        "composer": {"$ref": "#/definitions/artist"}
      }
    },
    "recording_artists": {
      "type": "array",
      "items": {"$ref": "#/definitions/artist"}
    }
  },
  "required": ["id", "work", "recording_artists"]
}"""

regex_samples = {
    "email": r"[a-z0-9!#$%&'*+/=?^_`{|}~-]+(?:\.[a-z0-9!#$%&'*+/=?^_`{|}~-]+)*@(?:[a-z0-9](?:[a-z0-9-]*[a-z0-9])?\.)+[a-z0-9](?:[a-z0-9-]*[a-z0-9])?",
    "complex_phone": "\\+?\\d{1,4}?[-.\\s]?\\(?\\d{1,3}?\\)?[-.\\s]?\\d{1,4}[-.\\s]?\\d{1,4}[-.\\s]?\\d{1,9}",
    "simple_phone": "\\+?[1-9][0-9]{7,14}",
    "date": "([1-9]|0[1-9]|1[0-9]|2[0-9]|3[0-1])(\.|-|/)([1-9]|0[1-9]|1[0-2])(\.|-|/)([0-9][0-9]|19[0-9][0-9]|20[0-9][0-9])|([0-9][0-9]|19[0-9][0-9]|20[0-9][0-9])(\.|-|/)([1-9]|0[1-9]|1[0-2])(\.|-|/)([1-9]|0[1-9]|1[0-9]|2[0-9]|3[0-1])",
    "time": r"(0?[1-9]|1[0-2]):[0-5]\d\s?(am|pm)?",
    "ip": r"(?:(?:25[0-5]|2[0-4][0-9]|[01]?[0-9][0-9]?)\.){3}(?:25[0-5]|2[0-4][0-9]|[01]?[0-9][0-9]?)",
    "url": r"(https?:\/\/)?([\da-z\.-]+)\.([a-z\.]{2,6})([\/\w \.-]*)*\/?",
    "ssn": r"\d{3}-\d{2}-\d{4}",
    "very_complex": <redacted>,
    "json_schema": build_regex_from_object(schema),
    "complex_json_schema": build_regex_from_object(complex_schema),
}

from outlines.fsm.fsm import RegexFSM

tokenizer = TransformerTokenizer("gpt2")


# Warm up: the first RegexFSM construction compiles the numba functions
t = time.time()
RegexFSM(regex_string='ab', tokenizer=tokenizer)
print(time.time() - t, "seconds for a simple initial regex which compiles numba functions")


def create_rfsm(r):
    RegexFSM(regex_string=r, tokenizer=tokenizer)


# Profile the most expensive sample and store the stats for later inspection
cProfile.run('create_rfsm(regex_samples["very_complex"])', 'profile_stats')


# Time index construction for each regex sample
for name, r in regex_samples.items():
    print()
    print(f"{name}: `{r}`")
    print()
    start = time.time()
    create_rfsm(r)
    print("-", time.time() - start)


# Print the profile of the very_complex run, sorted by cumulative time
p = pstats.Stats('profile_stats')
p.sort_stats('cumtime').print_stats()

@brandonwillard brandonwillard marked this pull request as draft January 6, 2024 15:01
@brandonwillard (Member) left a comment

We need profiling measurements to see which kinds of improvements the Numba changes introduce.

Review comments were left on pyproject.toml (outdated, resolved) and outlines/fsm/regex.py (resolved).
@brandonwillard (Member) commented

> Not done, but good to have: state_scan_tokens can be improved substantially by using a trie. We are performing an expensive _walk_fsm for every token * every state, which is the current bottleneck.

_walk_fsm should be very cheap. You can try setting up a trie for this, but we'll need profile comparisons to accept the changes.

@rlouf (Member) commented Jan 6, 2024

Great! We would need to profile the code in main and the code in this PR to get an idea of the performance gains.

@lapp0 (Contributor, Author) commented Jan 6, 2024

> Not done, but good to have: state_scan_tokens can be improved substantially by using a trie. We are performing an expensive _walk_fsm for every token * every state, which is the current bottleneck.
>
> _walk_fsm should be very cheap. You can try setting up a trie for this, but we'll need profile comparisons to accept the changes.

_walk_fsm is cheap, but calling it 17,500,000 times for an FSM with 350 potential states and a vocabulary of 50,000 tokens isn't. For large FSMs, these calls are the bottleneck. A trie may reduce the size of the checked vocabulary substantially, depending on the regex.

@brandonwillard (Member) commented

> Not done, but good to have: state_scan_tokens can be improved substantially by using a trie. We are performing an expensive _walk_fsm for every token * every state, which is the current bottleneck.
>
> _walk_fsm should be very cheap. You can try setting up a trie for this, but we'll need profile comparisons to accept the changes.
>
> _walk_fsm is cheap, but calling it 17,500,000 times for an FSM with 350 potential states and a vocabulary of 50,000 tokens isn't. For large FSMs, these calls are the bottleneck. A trie may reduce the size of the checked vocabulary substantially, depending on the regex.

Feel free to open a separate issue and/or PR for that idea.

A comment from @lapp0 was marked as a duplicate.

@brandonwillard (Member) commented

> Profile results indicate a marginal improvement (10–20% time reduction) for simple patterns, and a larger improvement (40–60% time reduction) for complex patterns.

Thanks for the results! Can you determine the performance gains for each change (i.e. the interegular change and the Numba one) separately? At the very least, we need the results of the Numba changes separate from the interegular ones.

@lapp0 (Contributor, Author) commented Jan 7, 2024

> Profile results indicate a marginal improvement (10–20% time reduction) for simple patterns, and a larger improvement (40–60% time reduction) for complex patterns.
>
> Thanks for the results! Can you determine the performance gains for each change (i.e. the interegular change and the Numba one) separately? At the very least, we need the results of the Numba changes separate from the interegular ones.

Please see the updated description.

@lapp0 (Contributor, Author) commented Jan 7, 2024

Use of a token trie resulted in walk_fsm being called 5,132 times for a 9-state regex ((0?[1-9]|1[0-2]):[0-5]\d\s?(am|pm)?).

Previously it was called 449,307 times (the vocabulary size is 49,923).

I updated the description with these benchmarks. A trie is beneficial except when the regex accepts many tokens in many states.
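
For illustration, here is a hedged sketch of how a single state could be scanned with such a trie (it assumes the trie from the earlier sketch and a hypothetical fsm_map dict of (state, char) -> next_state transitions, not this PR's actual implementation). When a prefix has no transition, every token below that trie node is skipped without a per-token walk:

def scan_state_with_trie(root, start_state, fsm_map):
    accepted = []  # (token_id, end_state) pairs reachable from start_state
    stack = [(root, start_state)]
    while stack:
        node, state = stack.pop()
        for token_id in node.token_ids:
            # The token's full text was consumed to reach this node.
            accepted.append((token_id, state))
        for char, child in node.children.items():
            next_state = fsm_map.get((state, char))
            if next_state is not None:
                # Descend only into subtrees whose prefix the FSM accepts;
                # a rejected prefix prunes all tokens below this child.
                stack.append((child, next_state))
    return accepted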

@lapp0 lapp0 marked this pull request as ready for review January 8, 2024 01:26
@rlouf (Member) commented Jan 8, 2024

Before I thoroughly review the PR:

  1. Could you fix the failing tests?
  2. Could you rebase your commits so they each correspond to one logical change? It will make review easier.
  3. The initial compilation time is important, as everything else can be eventually cached. I don't see any measurements of those?

@rlouf (Member) commented Jan 11, 2024

I think the best way forward here is to separate the trie and the Numba update to make conversations and evaluation easier. We might also need to add a benchmarking suite to track these changes over time in the library.

@lapp0 lapp0 force-pushed the speed-up-regexfsm branch 2 times, most recently from 903a084 to ced4bc1, on January 11, 2024
@lapp0 (Contributor, Author) commented Jan 11, 2024

> 1. Could you fix the failing tests?

Done

> 2. Could you rebase your commits so they each correspond to one logical change? It will make review easier.

Done

> 3. The initial compilation time is important, as everything else can be eventually cached. I don't see any measurements of those?

def create_rfsm(r):
    rfsm = RegexFSM(regex_string=r, tokenizer=tokenizer)

start = time.time()
create_rfsm("a")
print("compile time:", time.time() - start)

Ran "first-second-run" experiment twice on each branch, totalling 8 runs of the script above.

PR:

  • First run: (17.3140127658844, 17.043684244155884)
  • Second run after numba caches: (7.758531332015991, 7.461982488632202)

Master:

  • First run: (12.253467321395874, 12.760387420654297)
  • Second run after numba caches: (7.269283056259155, 7.479955196380615)

@rlouf (Member) commented Jan 11, 2024

I'm afraid we cannot merge something that makes first-time compilation slower. We can always make the following calls faster by caching the index construction, as is now the case on main.

As above I suggest breaking down this PR in two PRs: one for the Numba change and one for the trie so we can evaluate them independently. Unless you know which change is responsible for the slow down, in which case I would advise to reverse it and benchmark with the other one.

@lapp0 (Contributor, Author) commented Jan 11, 2024

> I'm afraid we cannot merge something that makes first-time compilation slower. We can always make the following calls faster by caching the index construction, as is now the case on main.
>
> As above I suggest breaking down this PR in two PRs: one for the Numba change and one for the trie so we can evaluate them independently. Unless you know which change is responsible for the slow down, in which case I would advise to reverse it and benchmark with the other one.

Any thoughts on AOT compilation? The only downside I see is that, if distributed via Docker, it would bind the library to a specific architecture.

@rlouf (Member) commented Jan 11, 2024

Let's split the PR before discussing approach-specific optimization.

@brandonwillard (Member) commented

> Any thoughts on AOT compilation? The only downside I see is that, if distributed via Docker, it would bind the library to a specific architecture.

A long-term viable AOT compilation approach in Numba isn't very clear right now. They're in the process of deprecating the current approach, and I believe the replacement is going to be PIXIE.

@lapp0 (Contributor, Author) commented Jan 13, 2024

I'm going to create a PR with some benchmarks using pytest-benchmark, then I'll make this two separate PRs.
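
A hypothetical sketch of what such a pytest-benchmark test could look like (the test name, parameters, and tokenizer fixture are assumptions for illustration, not the actual benchmark PR):

import pytest
from outlines.fsm.fsm import RegexFSM
from outlines.models.transformers import TransformerTokenizer


@pytest.fixture(scope="session")
def tokenizer():
    return TransformerTokenizer("gpt2")


@pytest.mark.parametrize(
    "pattern",
    [r"\d{3}-\d{2}-\d{4}", r"(0?[1-9]|1[0-2]):[0-5]\d\s?(am|pm)?"],
)
def test_regexfsm_construction(benchmark, tokenizer, pattern):
    # pytest-benchmark calls the target repeatedly and reports timing statistics
    benchmark(RegexFSM, regex_string=pattern, tokenizer=tokenizer)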

@lapp0 lapp0 closed this Jan 13, 2024
@brandonwillard brandonwillard changed the title Speed up RegexFSM Use a trie for scanning during index construction Apr 20, 2024
@brandonwillard brandonwillard added the enhancement, optimization, structured generation, and numba labels on Apr 20, 2024
@brandonwillard (Member) commented

It seems reasonable to consider this again after #768 (and even more so when Numba allows us to serialize the typed collections more easily).

Looks like this branch might need to be updated to account for the changes introduced by #738, though.

@brandonwillard brandonwillard marked this pull request as draft April 20, 2024 23:30
@brandonwillard brandonwillard linked an issue Apr 20, 2024 that may be closed by this pull request
@lapp0 (Contributor, Author) commented Apr 21, 2024

The benchmarks from #542 indicate this would result in compilation times ranging from a 30% increase to a 70% decrease, on average decreasing the time. It's especially beneficial for complex FSMs with many states.

I can't rebase at the moment, but I'm happy to review a rebase; please ping me on the PR.

@lapp0 (Contributor, Author) commented May 10, 2024

Closing in favor of #887
