
VLLM tensor-parallel and RegexLogitsProcessor #524

Closed
BenoitHardier opened this issue Jan 11, 2024 · 25 comments
Labels: bug, vLLM (Things involving vLLM support)

Comments

@BenoitHardier

Describe the issue as clearly as possible:

Hi,
I recently tried to use the RegexLogitsProcessor with vLLM, introduced in #481.

When using it with a "small" model like a 7B one on a single GPU it works fine, but when I try it with a big one, namely Mixtral, on multiple GPUs with the vLLM engine argument tensor-parallel, I run into several problems (monkey patching not working and fsm_state in the processor not initialized). I suspect the multiple Ray workers are the cause (the monkey patch may not be propagated to all the workers, and the same goes for the fsm_states).

I could have missed some relevant information, but it seems that #481 was only tested without tensor-parallel.

Steps/code to reproduce the bug:

import vllm
import vllm.model_executor.layers.sampler as sampler
from pydantic import BaseModel

from outlines.serve.vllm import JSONLogitsProcessor, _patched_apply_logits_processors

# Patch the _apply_logits_processors so it is compatible with `JSONLogitsProcessor`
sampler._apply_logits_processors = _patched_apply_logits_processors


class User(BaseModel):
    id: int
    name: str

model = "mistralai/Mixtral-8X7B-Instruct-v0.1"
#model = "mistralai/Mistral-7B-Instruct-v0.2"
llm = vllm.LLM(model=model, dtype='float16', max_model_len=1024,  tensor_parallel_size=4, max_num_seqs=512, enforce_eager=True)
logits_processor = JSONLogitsProcessor(User, llm)
result = llm.generate(
    ["A prompt", "Another prompt"],
    sampling_params=vllm.SamplingParams(
        max_tokens=100, logits_processors=[logits_processor]
    ),
)
print(result)

Expected result:

result from vllm

Error message:

Traceback (most recent call last):
  File "/workspace/./python_scripts/vllm_integration.py", line 19, in <module>
    result = llm.generate(
  File "/usr/local/lib/python3.10/dist-packages/vllm/entrypoints/llm.py", line 165, in generate
    return self._run_engine(use_tqdm)
  File "/usr/local/lib/python3.10/dist-packages/vllm/entrypoints/llm.py", line 185, in _run_engine
    step_outputs = self.llm_engine.step()
  File "/usr/local/lib/python3.10/dist-packages/vllm/engine/llm_engine.py", line 581, in step
    output = self._run_workers(
  File "/usr/local/lib/python3.10/dist-packages/vllm/engine/llm_engine.py", line 755, in _run_workers
    self._run_workers_in_batch(workers, method, *args, **kwargs))
  File "/usr/local/lib/python3.10/dist-packages/vllm/engine/llm_engine.py", line 732, in _run_workers_in_batch
    all_outputs = ray.get(all_outputs)
  File "/usr/local/lib/python3.10/dist-packages/ray/_private/auto_init_hook.py", line 22, in auto_init_wrapper
    return fn(*args, **kwargs)
  File "/usr/local/lib/python3.10/dist-packages/ray/_private/client_mode_hook.py", line 103, in wrapper
    return func(*args, **kwargs)
  File "/usr/local/lib/python3.10/dist-packages/ray/_private/worker.py", line 2624, in get
    raise value.as_instanceof_cause()
ray.exceptions.RayTaskError(TypeError): ray::RayWorkerVllm.execute_method() (pid=250471, ip=172.17.0.12, actor_id=6003ef308c0e22f0e05d73c901000000, repr=<vllm.engine.ray_utils.RayWorkerVllm object at 0x7f8c9c37a770>)
  File "/usr/local/lib/python3.10/dist-packages/vllm/engine/ray_utils.py", line 31, in execute_method
    return executor(*args, **kwargs)
  File "/usr/local/lib/python3.10/dist-packages/torch/utils/_contextlib.py", line 115, in decorate_context
    return func(*args, **kwargs)
  File "/usr/local/lib/python3.10/dist-packages/vllm/worker/worker.py", line 159, in execute_model
    output = self.model_runner.execute_model(seq_group_metadata_list,
  File "/usr/local/lib/python3.10/dist-packages/torch/utils/_contextlib.py", line 115, in decorate_context
    return func(*args, **kwargs)
  File "/usr/local/lib/python3.10/dist-packages/vllm/worker/model_runner.py", line 354, in execute_model
    output = self.model.sample(
  File "/usr/local/lib/python3.10/dist-packages/vllm/model_executor/models/mixtral.py", line 390, in sample
    next_tokens = self.sampler(self.lm_head.weight, hidden_states,
  File "/usr/local/lib/python3.10/dist-packages/torch/nn/modules/module.py", line 1518, in _wrapped_call_impl
    return self._call_impl(*args, **kwargs)
  File "/usr/local/lib/python3.10/dist-packages/torch/nn/modules/module.py", line 1527, in _call_impl
    return forward_call(*args, **kwargs)
  File "/usr/local/lib/python3.10/dist-packages/vllm/model_executor/layers/sampler.py", line 52, in forward
    logits = _apply_logits_processors(logits, sampling_metadata)
  File "/usr/local/lib/python3.10/dist-packages/vllm/model_executor/layers/sampler.py", line 172, in _apply_logits_processors
    logits_row = logits_processor(token_ids, logits_row)
TypeError: RegexLogitsProcessor.__call__() missing 1 required positional argument: 'scores'

Outlines/Python version information:

Outlines 0.0.22
Python 3.10.12

Context for the issue:

No response

@viktor-ferenczi

I have the same issue.

Check the changes in this vLLM PR (not merged yet):
vllm-project/vllm#2105

It integrates grammar into vLLM itself and contains code that works with Ray, which may give some ideas on how to solve this.

This is the specific code:

    if request.grammar:
        if engine.worker_use_ray:
            grammar_logits_processor = RayRemoteGrammarLogitsProcessor(
                tokenizer=tokenizer, grammar=request.grammar)
        else:
            grammar_logits_processor = GrammarLogitsProcessor(
                tokenizer=tokenizer, grammar=request.grammar)
        logits_processors = [grammar_logits_processor]
    else:
        logits_processors = []

I guess we need a class similar to RayRemoteGrammarLogitsProcessor.
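
For illustration, a minimal sketch of such a wrapper (names like `_GrammarActor` and `RayRemoteLogitsProcessor` are illustrative, not the PR's actual code): the grammar/FSM state lives in a single Ray actor, and the processor object held by each worker only forwards calls to it.

```python
import ray
import torch


@ray.remote
class _GrammarActor:
    """Single process that owns the grammar/FSM state for all workers."""

    def __init__(self, tokenizer, grammar):
        self.tokenizer = tokenizer
        self.grammar = grammar
        self.states = {}  # hash(token_ids) -> parser/FSM state

    def process(self, token_ids: list, scores: torch.Tensor) -> torch.Tensor:
        # look up / advance the state for this sequence and mask `scores`;
        # the masking itself is elided because it depends on the FSM implementation
        return scores


class RayRemoteLogitsProcessor:
    """Drop-in logits processor that delegates every call to the shared actor."""

    def __init__(self, tokenizer, grammar):
        self.actor = _GrammarActor.remote(tokenizer, grammar)

    def __call__(self, token_ids: list, scores: torch.Tensor) -> torch.Tensor:
        # logits travel GPU -> CPU -> actor -> CPU -> GPU on every decode step
        out = ray.get(self.actor.process.remote(token_ids, scores.cpu()))
        return out.to(scores.device)
```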

@lapp0
Collaborator

lapp0 commented Jan 14, 2024

I'm considering closing the vllm PR and moving the work over here since outlines has a more mature, fleshed out implementation.

I'm seeing a substantial performance degradation with RayRemoteGrammarLogitsProcessor. I haven't profiled, but my guess is that sending a tensor over Ray is not performant.

I will experiment with the actor only taking the tokens as input and returning a boolean mask. This would prevent the tensor from going back and forth through Ray.

https://numpy.org/doc/stable/reference/generated/numpy.ma.MaskedArray.tobytes.html
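
A rough sketch of that idea (hypothetical names, not a tested implementation; the `fsm` object is assumed to expose `next_state()` and `allowed_token_ids()`): the actor receives only the generated token ids and returns a packed mask, so the logits tensor never crosses the process boundary.

```python
import ray
import torch


@ray.remote
class _MaskActor:
    """Owns the FSM state; receives token ids only and returns a packed mask."""

    def __init__(self, fsm, vocab_size: int):
        # `fsm` is assumed to expose next_state(state, token_id) and
        # allowed_token_ids(state), with 0 as the initial state
        self.fsm = fsm
        self.vocab_size = vocab_size
        self.states = {}  # hash(token_ids) -> FSM state

    def get_mask(self, token_ids: list) -> bytes:
        state = self.states.get(hash(tuple(token_ids[:-1])), 0)
        if token_ids:
            state = self.fsm.next_state(state, token_ids[-1])
            self.states[hash(tuple(token_ids))] = state
        mask = bytearray(self.vocab_size)   # one byte per vocab entry
        for t in self.fsm.allowed_token_ids(state):
            mask[t] = 1
        return bytes(mask)


class MaskLogitsProcessor:
    def __init__(self, fsm, vocab_size: int):
        self.actor = _MaskActor.remote(fsm, vocab_size)

    def __call__(self, token_ids: list, scores: torch.Tensor) -> torch.Tensor:
        packed = ray.get(self.actor.get_mask.remote(token_ids))
        allowed = torch.frombuffer(bytearray(packed), dtype=torch.uint8).bool()
        return scores.masked_fill(~allowed.to(scores.device), float("-inf"))
```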

@viktor-ferenczi

I agree with closing the vLLM ticket. I think vLLM needs another, much smaller PR. That should include only small changes to pass enough information to outlines, so no tricky patches are needed. Namely, vLLM needs to pass the seq_id to the logits processors. It is an API change, but a warranted one.

Regarding Ray, it seems to be a good idea, indeed. Try to pass numpy arrays; it may be faster. Also, the tokens may be largely redundant, so sending only the difference may also help decrease the amount of information sent between processes.

@rlouf
Member

rlouf commented Jan 15, 2024

I agree with closing the vLLM ticket. I think vLLM needs another, much smaller PR. That should include only small changes to pass enough information to outlines, so no tricky patches are needed. Namely, vLLM needs to pass the seq_id to the logits processors. It is an API change, but a warranted one.

We can think of other ways to track the FSM state, like using token_ids as keys as well. Someone mentioned it earlier but I can't find the comment.

@lapp0
Collaborator

lapp0 commented Jan 15, 2024

Regarding Ray, it seems to be a good idea, indeed. Try to pass numpy arrays; it may be faster. Also, the tokens may be largely redundant, so sending only the difference may also help decrease the amount of information sent between processes.

Logit processors aren't associated with one specific sequence, so this is challenging.

_patched_apply_logits_processors does not work for generating multiple sequences at once because seq_id doesn't necessarily map to the same sequence between generations.

However, patching also isn't necessary if you have correct cached-state lookups.

It's necessary to send a representation of the token sequence. Either the minimal (hash(prev_token_id), last_token_id) or just token_ids needs to be sent to the logits-processing actor so it can look up the state.

The actor can send back the bytes form of the mask.

We can think of other ways to track the FSM state, like using token_ids as keys as well. Someone mentioned it earlier but I can't find the comment.

#416 (comment)

This is necessary to get concurrent generation within a sequence group in vLLM working. Otherwise the generated tokens won't necessarily be added to the correct sequence. I got beam search working with outlines by using the token_ids as the FSM state key.
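
A condensed sketch of that keying scheme (illustrative names, not the exact outlines code; the `fsm` interface is assumed): the processor caches FSM state under a hash of the token ids generated so far, so the lookup works for beam search and concurrent sequences without relying on seq_id.

```python
from collections import defaultdict

import torch


class TokenKeyedRegexProcessor:
    """Illustrative sketch: FSM state keyed by hash of the generated token ids."""

    def __init__(self, fsm):
        # `fsm` is assumed to expose next_state(state, token_id) and
        # allowed_token_ids(state); 0 is assumed to be the initial state
        self.fsm = fsm
        self.fsm_state = defaultdict(int)   # hash(token_ids) -> FSM state

    def __call__(self, token_ids: list, scores: torch.Tensor) -> torch.Tensor:
        key = hash(tuple(token_ids))
        if token_ids:
            # advance from the state cached for the prefix, using only the newest token
            prev_key = hash(tuple(token_ids[:-1]))
            self.fsm_state[key] = self.fsm.next_state(
                self.fsm_state[prev_key], token_ids[-1]
            )
        allowed = self.fsm.allowed_token_ids(self.fsm_state[key])
        mask = torch.full_like(scores, float("-inf"))
        mask[allowed] = 0                   # leave allowed tokens untouched
        return scores + mask
```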

@viktor-ferenczi

It explains why my attempt to use beam search with outlines regex failed yesterday. I like the approach of taking a hash of previous tokens.

@lapp0
Collaborator

lapp0 commented Jan 15, 2024

@viktor-ferenczi could you try this PR and tell me if it also works for you?

#539

@viktor-ferenczi

Sure, I will try this tomorrow.

@viktor-ferenczi

Tested both in regex and json_schema modes

Failed with errors:

AttributeError: 'RegexLogitsProcessor' object has no attribute 'fsm_state'
AttributeError: 'JSONLogitsProcessor' object has no attribute 'fsm_state'

Full tracebacks for regex and json_schema modes:
error.txt

The cause looks like self.fsm_state is not set in RegexLogitsProcessor.__init__. The member is then accessed when the logits processor is first called, without ever having been initialized.

Appending this line to RegexLogitsProcessor.__init__ fixes the crash:

self.fsm_state: DefaultDict[int, int] = defaultdict(int)
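
(For context, a rough sketch of where the line lands; the shape of __init__ is hypothetical and the rest of its body is elided, this is not the actual outlines source.)

```python
from collections import defaultdict
from typing import DefaultDict


class RegexLogitsProcessor:
    def __init__(self, regex_string, llm):
        ...  # FSM construction from regex_string and the model tokenizer elided
        # without this line, the first __call__ hits the AttributeError above
        self.fsm_state: DefaultDict[int, int] = defaultdict(int)
```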

But the constraint does not work as expected: it generates only 1-2 characters, then stops.

GPUs: 2x4090 (2x24GB)

vLLM command:

python -O -u -m outlines.serve.serve \
  --model=TheBloke/deepseek-coder-33B-instruct-AWQ \
  --quantization=awq \
  --dtype=float16 \
  --host=0.0.0.0 \
  --port=8000 \
  --max-model-len=16384 \
  --max-num-seqs=16 \
  --tensor-parallel-size=2 \
  --swap-space=8 \
  --gpu-memory-utilization=0.95 \
  --enforce-eager \
  --disable-log-requests

vLLM request:

{
  "prompt": "You are an AI programming assistant, utilizing the Deepseek Coder model, developed by Deepseek Company, and you only answer questions related to computer science. For politically sensitive questions, security and privacy issues, and other non-computer science questions, you will refuse to answer.\n\nYou are a helpful AI assistant. You give concise answers. If you do not know something, then say so.\n### Instruction:\nWrite down the first 10 prime numbers as a comma separated list.\n\n### Response:\n",
  "n": 1,
  "best_of": 1,
  "presence_penalty": 0.0,
  "frequency_penalty": 0.0,
  "repetition_penalty": 1.0,
  "temperature": 0.0,
  "top_p": 1.0,
  "top_k": -1,
  "min_p": 0.0,
  "use_beam_search": false,
  "length_penalty": 1.0,
  "early_stopping": false,
  "stop": [],
  "stop_token_ids": [],
  "include_stop_str_in_output": false,
  "ignore_eos": false,
  "max_tokens": 50,
  "logprobs": null,
  "prompt_logprobs": null,
  "skip_special_tokens": true,
  "spaces_between_special_tokens": true,
  "regex": "\\d+(\\s*,\\s*\\d+)*\\s*"
}

Response:

{
  "text": [
    "You are an AI programming assistant, utilizing the Deepseek Coder model, developed by Deepseek Company, and you only answer questions related to computer science. For politically sensitive questions, security and privacy issues, and other non-computer science questions, you will refuse to answer.\n\nYou are a helpful AI assistant. You give concise answers. If you do not know something, then say so.\n### Instruction:\nWrite down the first 10 prime numbers as a comma separated list.\n\n### Response:\n1,"
  ]
}

@lapp0
Collaborator

lapp0 commented Jan 17, 2024

Thanks for running it. I'll write some more test cases; I think it got messed up on a rebase, my mistake.

I believe the early termination is a result of the issue being fixed in #544

@viktor-ferenczi

Thank you, please let me know when I can test it again. I will try to cherry-pick the early termination fix as well.

@viktor-ferenczi

viktor-ferenczi commented Jan 17, 2024

@lapp0 Please find the fixed RegexLogitsProcessor class in my comment on your PR (#539), that's tested to work, even without the early termination fix.

I keep the working code in my dev branch here, so you can cherry-pick the fix from there as well: https://github.com/viktor-ferenczi/outlines/tree/dev

UPDATE: Tested OK (with my fix) for multiple sequence generation and beam search as well. All seems to work well.

rlouf added the vLLM (Things involving vLLM support) label on Jan 18, 2024
@rlouf
Member

rlouf commented Jan 18, 2024

Do we expect this issue to be solved by #539? If so, we should link it to the PR.

@lapp0
Collaborator

lapp0 commented Jan 19, 2024

Do we expect this issue to be solved by #539? If so, we should link it to the PR.

Nope, that PR just prevents corruption when generating multiple sequences concurrently in vLLM. Tensor parallel doesn't work with that PR.

@viktor-ferenczi

I actually use the changes from #539 with minor modifications (see there) with --tensor-parallel=2 and it works.

@lapp0
Collaborator

lapp0 commented Jan 19, 2024

@viktor-ferenczi Could you test how many tokens are parsed on each step with your recursive solution? My concern is that the entire generated sequence is being re-parsed at each step if you don't have a shared cache between actors.

Regarding using a ray actor,

I did some performance testing:

  • round trip time for a (1, 2**16) tensor (logits) being sent to a ray actor and received from a ray actor is ~0.005s
  • time for a sent 4096 integer list (generated sequence) and a returned 2**16 bool mask is ~0.0015s
  • Moving logits from GPU to CPU to ray actor and back to GPU again adds negligible overhead

https://gist.github.com/lapp0/409a7c3a7f9880b606626bb283f0b01c
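
A rough sketch of that kind of measurement (a hypothetical microbenchmark, not the linked gist; absolute numbers depend on the machine):

```python
import time

import ray
import torch

ray.init(ignore_reinit_error=True)


@ray.remote
class Echo:
    def roundtrip_logits(self, logits):
        return logits                       # full logits tensor both ways

    def roundtrip_mask(self, token_ids):
        return [False] * (2 ** 16)          # token ids in, boolean mask out


actor = Echo.remote()
logits = torch.zeros(1, 2 ** 16)
token_ids = list(range(4096))

t0 = time.perf_counter()
ray.get(actor.roundtrip_logits.remote(logits))
print(f"logits tensor round trip: {time.perf_counter() - t0:.4f}s")

t0 = time.perf_counter()
ray.get(actor.roundtrip_mask.remote(token_ids))
print(f"token ids -> mask round trip: {time.perf_counter() - t0:.4f}s")
```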

@viktor-ferenczi

Sure, I can test it after work if I have any energy left.

@lapp0
Collaborator

lapp0 commented Jan 19, 2024

@viktor-ferenczi can confirm, the recursive solution re-parses the entire sequence from start every time

https://gist.github.com/lapp0/8370bde4d977088487c34bc7501b78af

We should ensure the cache is saved in a stable object to address tensor parallel.

@lapp0
Collaborator

lapp0 commented Jan 19, 2024

@simon-mo do you have any recommendations for the problem we're facing here?

We maintain the parser state within the logits processor. Without making any adjustments, the logits processor object will be replicated and have inconsistent state between Ray processes. In the vLLM grammars PR I resolved this by creating a Ray actor, but the implementation had some performance issues.

Do you think wrapping the logits processor with a Ray actor is the appropriate solution, or would you suggest an alternative?

@viktor-ferenczi

viktor-ferenczi commented Jan 21, 2024

@viktor-ferenczi can confirm, the recursive solution re-parses the entire sequence from start every time

https://gist.github.com/lapp0/8370bde4d977088487c34bc7501b78af

We should ensure the cache is saved in a stable object to address tensor parallel.

@lapp0 Please check out the code in my dev branch (pushed a new one):

https://github.com/viktor-ferenczi/outlines/commits/dev/

It contains the changes from your branch, plus a solution for caching the RegexFSM objects by regex_string and caching the states separately for each of them. There is no cleanup of unused cache entries yet; that will be implemented if/when the general caching solution is approved.

Tested OK for single and parallel regex and JSON schema constraints. There seems to be negligible performance impact, so the caching seems to be efficient and persistent across sequences.

I also double-checked that the logits processor is always called from the same process and the same thread, even when tensor parallel is greater than one, so there is no need to worry about the thread safety of the cache.

Please let me know whether this solution works for you. I have started using it myself to test it further.
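
A compressed sketch of that two-level caching idea (illustrative names and helpers; the real code lives in the linked branch): FSMs cached by regex string, and FSM states cached per-FSM by the hash of the token ids seen so far.

```python
from collections import defaultdict
from typing import Any, Callable, Dict

_FSM_CACHE: Dict[str, Any] = {}                               # regex string -> FSM object
_STATE_CACHE: Dict[str, Dict[int, int]] = defaultdict(dict)   # regex -> {hash(token_ids): state}


def get_fsm(regex_string: str, build_fsm: Callable[[str], Any]) -> Any:
    """Build the FSM once per distinct regex and reuse it across requests."""
    if regex_string not in _FSM_CACHE:
        _FSM_CACHE[regex_string] = build_fsm(regex_string)
    return _FSM_CACHE[regex_string]


def get_state(regex_string: str, token_ids: tuple) -> int:
    """State reached after `token_ids`, defaulting to the initial state 0."""
    return _STATE_CACHE[regex_string].get(hash(token_ids), 0)


def set_state(regex_string: str, token_ids: tuple, state: int) -> None:
    _STATE_CACHE[regex_string][hash(token_ids)] = state
```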

@lapp0
Collaborator

lapp0 commented Jan 21, 2024

@viktor-ferenczi

Running benchmarks on CFGFSM, the performance of your branch is nearly equal to mine (actually slightly better), and substantially better than the recursive solution.

The cost is one hash -> int cache entry for every token generated, so we definitely need an eviction policy. We should also test larger generations; the performance tests below max out at ~2700 tokens.

Thanks for working this out! I'll work on integrating your changes into the PR.
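
One possible shape for the eviction policy mentioned above (a plain LRU bound on the hash -> state cache; hypothetical, not what was ultimately merged):

```python
from collections import OrderedDict


class BoundedStateCache:
    """hash(token_ids) -> FSM state, bounded by an LRU eviction policy."""

    def __init__(self, max_entries: int = 100_000):
        self.max_entries = max_entries
        self._data: OrderedDict[int, int] = OrderedDict()

    def get(self, key: int, default: int = 0) -> int:
        if key in self._data:
            self._data.move_to_end(key)     # mark as most recently used
            return self._data[key]
        return default

    def set(self, key: int, state: int) -> None:
        self._data[key] = state
        self._data.move_to_end(key)
        while len(self._data) > self.max_entries:
            self._data.popitem(last=False)  # drop the least recently used entry
```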

My branch (beam search working, tensor parallel not working)

Additional Benchmark Details:
tests/benchmark/test_benchmark_cfg_generation.py::test_benchmark_cfg_generation[json-True]:
	Tokens / Second: 113.506
	(Num Tokens: 2706, Time: 23.840 seconds)
tests/benchmark/test_benchmark_cfg_generation.py::test_benchmark_cfg_generation[json-False]:
	Tokens / Second: 103.270
	(Num Tokens: 2706, Time: 26.203 seconds)
tests/benchmark/test_benchmark_cfg_generation.py::test_benchmark_cfg_generation[csv-True]:
	Tokens / Second: 15.249
	(Num Tokens: 79, Time: 5.181 seconds)
tests/benchmark/test_benchmark_cfg_generation.py::test_benchmark_cfg_generation[csv-False]:
	Tokens / Second: 10.869
	(Num Tokens: 79, Time: 7.268 seconds)


-------------------------------------------------------------------------------------------- benchmark: 4 tests --------------------------------------------------------------------------------------------
Name (time in s)                                  Min                Max               Mean            StdDev             Median               IQR            Outliers     OPS            Rounds  Iterations
------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------
test_benchmark_cfg_generation[csv-True]        5.1808 (1.0)       5.1808 (1.0)       5.1808 (1.0)      0.0000 (1.0)       5.1808 (1.0)      0.0000 (1.0)           0;0  0.1930 (1.0)           1           1
test_benchmark_cfg_generation[csv-False]       7.2682 (1.40)      7.2682 (1.40)      7.2682 (1.40)     0.0000 (1.0)       7.2682 (1.40)     0.0000 (1.0)           0;0  0.1376 (0.71)          1           1
test_benchmark_cfg_generation[json-True]      23.8402 (4.60)     23.8402 (4.60)     23.8402 (4.60)     0.0000 (1.0)      23.8402 (4.60)     0.0000 (1.0)           0;0  0.0419 (0.22)          1           1
test_benchmark_cfg_generation[json-False]     26.2031 (5.06)     26.2031 (5.06)     26.2031 (5.06)     0.0000 (1.0)      26.2031 (5.06)     0.0000 (1.0)           0;0  0.0382 (0.20)          1           1
------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------

Viktor's branch (beam search AND tensor parallel working)

tests/benchmark/test_benchmark_cfg_generation.py ....                                                                                                                            [100%]

Additional Benchmark Details:
tests/benchmark/test_benchmark_cfg_generation.py::test_benchmark_cfg_generation[json-True]:
	Tokens / Second: 116.989
	(Num Tokens: 2706, Time: 23.130 seconds)
tests/benchmark/test_benchmark_cfg_generation.py::test_benchmark_cfg_generation[json-False]:
	Tokens / Second: 107.233
	(Num Tokens: 2706, Time: 25.235 seconds)
tests/benchmark/test_benchmark_cfg_generation.py::test_benchmark_cfg_generation[csv-True]:
	Tokens / Second: 15.780
	(Num Tokens: 79, Time: 5.006 seconds)
tests/benchmark/test_benchmark_cfg_generation.py::test_benchmark_cfg_generation[csv-False]:
	Tokens / Second: 11.234
	(Num Tokens: 79, Time: 7.032 seconds)


-------------------------------------------------------------------------------------------- benchmark: 4 tests --------------------------------------------------------------------------------------------
Name (time in s)                                  Min                Max               Mean            StdDev             Median               IQR            Outliers     OPS            Rounds  Iterations
------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------
test_benchmark_cfg_generation[csv-True]        5.0064 (1.0)       5.0064 (1.0)       5.0064 (1.0)      0.0000 (1.0)       5.0064 (1.0)      0.0000 (1.0)           0;0  0.1997 (1.0)           1           1
test_benchmark_cfg_generation[csv-False]       7.0320 (1.40)      7.0320 (1.40)      7.0320 (1.40)     0.0000 (1.0)       7.0320 (1.40)     0.0000 (1.0)           0;0  0.1422 (0.71)          1           1
test_benchmark_cfg_generation[json-True]      23.1304 (4.62)     23.1304 (4.62)     23.1304 (4.62)     0.0000 (1.0)      23.1304 (4.62)     0.0000 (1.0)           0;0  0.0432 (0.22)          1           1
test_benchmark_cfg_generation[json-False]     25.2348 (5.04)     25.2348 (5.04)     25.2348 (5.04)     0.0000 (1.0)      25.2348 (5.04)     0.0000 (1.0)           0;0  0.0396 (0.20)          1           1
------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------

@simon-mo

the logits processor object will be replicated and have inconsistent state between Ray processes

@lapp0 In vLLM, I recall the sampling procedure is only done in a single process (the driver process). I'm quite confused why this would happen.

@wdhitchc

Has this been resolved, or is it close to being resolved?

I am still getting the TypeError: RegexLogitsProcessor.__call__() missing 1 required positional argument: 'scores' error.

@lapp0
Collaborator

lapp0 commented Jan 26, 2024

@wdhitchc we just got some important prerequisites merged; I hope to have it ready for review this weekend.

@lapp0
Collaborator

lapp0 commented Feb 6, 2024

@wdhitchc @BenoitHardier my smoke tests suggest main works for tensor parallel. Could you please try again and let me know if you run into any issues (reporting your success also helps!)

rlouf pushed a commit that referenced this issue Feb 14, 2024
For `outlines/vllm`, FSM-sequence correspondence was previously broken, resulting in FSM state being mixed between sequences and corrupting output. To alleviate this, we have `_patched_apply_logits_processor`, which passes a stable sequence ID to the logits processor.

In this PR we eliminate `_patched_apply_logits_processor` and cache FSM state based on the sequence's input IDs.

Continuation of #539, but much simpler because the vLLM upgrade fixed a lot of the issues being addressed there.

Related discussions:
- #624

Fixes:
- Fixes #605 
- Fixes #610

Already fixed:
- #524 (this one can be closed, as it was addressed previously by upgrading vLLM)


@viktor-ferenczi can you please confirm whether this branch fixes either
#610 or
#605

# Smoke tests

### basic parallel

passed

<details>

```
import json
import vllm
from pydantic import BaseModel
from typing import List
import torch
import pandas as pd
from outlines.serve.vllm import JSONLogitsProcessor

class ConceptsList(BaseModel):
    concepts: List[str]

BASE_MODEL = "microsoft/phi-2"
llm = vllm.LLM(model=BASE_MODEL, tensor_parallel_size=1, dtype=torch.float16, max_model_len=2048)

logits_processor = JSONLogitsProcessor(ConceptsList, llm.llm_engine)

full_prompts = [
    f"Provide me a list of {i} strings with key 'concepts'"
    for i in range(20)
]

batch_results = llm.generate(
    full_prompts,
    sampling_params=vllm.SamplingParams(
        max_tokens=2048, logits_processors=[logits_processor]
    ),
)


for result in batch_results:
    for output in result.outputs:
            json.loads(output.text)
```

</details>


### never ending regex

passed

<details>

`python3 -m outlines.serve.serve --model="microsoft/phi-2"`

```
curl http://127.0.0.1:8000/generate \
    -d '{
        "prompt": "Sequence of numbers and letters:",
        "regex": "([123]-[abc]-([def]-)?)*",
        "n": 7
}'
{"text":["Sequence of numbers and letters:1-a-1-b-1-c-1-a-","Sequence of numbers and letters:1-a-2-b-3-c-1-a-","Sequence of numbers and letters:1-a-2-b-3-c-d-1-","Sequence of numbers and letters:2-a-1-b-2-c-1-b-","Sequence of numbers and letters:2-b-3-c-d-2-b-3-","Sequence of numbers and letters:2-a-3-b-2-b-1-c-","Sequence of numbers and letters:2-a-3-b-d-2-a-3-"]}


# rules for the above to validate correct FSM-sequence correspondence:
# [123] always followed by [abc], [def] only ever preceded by [abc]

# 1-a-1-b-1-c-1-a-
# 1-a-2-b-3-c-1-a-
# 1-a-2-b-3-c-d-1-
# 2-a-1-b-2-c-1-b-
# 2-b-3-c-d-2-b-3-
# 2-a-3-b-2-b-1-c-
# 2-a-3-b-d-2-a-3-
```

</details>


### sometimes ending early regex

passed

<details>

`python3 -m outlines.serve.serve --model="microsoft/phi-2"`

```
curl http://127.0.0.1:8000/generate \
    -d '{
        "prompt": "Sequence of numbers and letters:",
        "regex": "([123]-[abc]-([def]-)?){3}",
        "n": 16
}'
```

output

```
{"text":["Sequence of numbers and letters:1-a-2-b-3-c-d-","Sequence of numbers and letters:1-a-2-b-3-c-d-","Sequence of numbers and letters:1-a-2-b-3-c-d-","Sequence of numbers and letters:1-a-2-b-3-c-d-","Sequence of numbers and letters:1-a-2-b-3-c-d-","Sequence of numbers and letters:3-a-1-b-2-c-d-","Sequence of numbers and letters:2-a-1-b-3-c-d-","Sequence of numbers and letters:1-a-1-b-1-c-d-","Sequence of numbers and letters:2-a-3-b-d-1-c-e-","Sequence of numbers and letters:1-b-3-a-2-c-d-","Sequence of numbers and letters:3-a-d-1-b-e-2-c-","Sequence of numbers and letters:1-a-3-b-1-b-d-","Sequence of numbers and letters:3-a-f-2-b-d-1-c-","Sequence of numbers and letters:1-b-d-3-a-e-2-c-","Sequence of numbers and letters:3-c-1-b-d-1-a-e-","Sequence of numbers and letters:1-c-1-c-e-1-b-e-"]}
```

analysis:

```
1-a-2-b-3-c-d-
1-a-2-b-3-c-d-
1-a-2-b-3-c-d-
1-a-2-b-3-c-d-
1-a-2-b-3-c-d-
1-a-2-b-3-c-d-
3-a-1-b-2-c-d-
2-a-1-b-3-c-d-
1-a-1-b-1-c-d-
2-a-3-b-d-1-c-e-
1-b-3-a-2-c-d-
3-a-d-1-b-e-2-c-
1-a-3-b-1-b-d-
3-a-f-2-b-d-1-c-
1-b-d-3-a-e-2-c-
3-c-1-b-d-1-a-e-
1-c-1-c-e-1-b-e-
```

Observations:
- All patterns are correct
- Patterns don't "borrow" FSM state from one-another, they retain their
own independent state
- Some patterns produced more tokens than others successfully


</details>

### Viktor's regex

passed

<details>

`python3 -m outlines.serve.serve --model="microsoft/phi-2"`

```
curl http://127.0.0.1:8000/generate \
    -d '{
  "prompt": "You are an AI programming assistant, utilizing the Deepseek Coder model, developed by Deepseek Company, and you only answer questions related to computer science. For politically sensitive questions, security and privacy issues, and other non-computer science questions, you will refuse to answer.\n\nYou are a helpful AI assistant. You give concise answers. If you do not know something, then say so.\n### Instruction:\nWrite down the first 10 prime numbers as a comma separated list, starting with 2.\n\n### Response:\n",
  "n": 1,
  "best_of": 1,
  "presence_penalty": 0.0,
  "frequency_penalty": 0.0,
  "repetition_penalty": 1.0,
  "temperature": 0.0,
  "top_p": 1.0,
  "top_k": -1,
  "min_p": 0.0,
  "use_beam_search": false,
  "length_penalty": 1.0,
  "early_stopping": false,
  "stop": [],
  "stop_token_ids": [],
  "include_stop_str_in_output": false,
  "ignore_eos": false,
  "max_tokens": 50,
  "logprobs": null,
  "prompt_logprobs": null,
  "skip_special_tokens": true,
  "spaces_between_special_tokens": true,
  "regex": "\\d+(\\s*,\\s*\\d+)*\\s*"
}'
```

output:

```
{"text":["You are an AI programming assistant, utilizing the Deepseek Coder model, developed by Deepseek Company, and you only answer questions related to computer science. For politically sensitive questions, security and privacy issues, and other non-computer science questions, you will refuse to answer.\n\nYou are a helpful AI assistant. You give concise answers. If you do not know something, then say so.\n### Instruction:\nWrite down the first 10 prime numbers as a comma separated list, starting with 2.\n\n### Response:\n2, 3, 5, 7, 11, 13, 17, 19, 23, 29\n"]}
```

</details>

### Viktors schema

passed

<details>

`python3 -m outlines.serve.serve --model="microsoft/phi-2"`

```
curl http://127.0.0.1:8000/generate \
    -d '{
  "prompt": "You are an AI programming assistant, utilizing the Deepseek Coder model, developed by Deepseek Company, and you only answer questions related to computer science. For politically sensitive questions, security and privacy issues, and other non-computer science questions, you will refuse to answer.\n\nYou are a helpful AI assistant. You give concise answers. If you do not know something, then say so.\n### Instruction:\nWrite a JSON describing a random fruit. It must conform to the following JSON schema: {\"properties\": {\"kind\": {\"title\": \"Kind\", \"type\": \"string\"}, \"color\": {\"title\": \"Color\", \"type\": \"string\"}, \"count\": {\"title\": \"Count\", \"type\": \"integer\"}, \"weight\": {\"title\": \"Weight\", \"type\": \"number\"}, \"sweet\": {\"title\": \"Sweet\", \"type\": \"boolean\"}}, \"required\": [\"kind\", \"color\", \"count\", \"weight\", \"sweet\"], \"title\": \"Fruit\", \"type\": \"object\"}\n\n### Response:\n",
  "n": 5,
  "best_of": 5,
  "presence_penalty": 0.0,
  "frequency_penalty": 0.0,
  "repetition_penalty": 1.0,
  "temperature": 1.0,
  "top_p": 1.0,
  "top_k": -1,
  "min_p": 0.0,
  "use_beam_search": false,
  "length_penalty": 1.0,
  "early_stopping": false,
  "stop": [],
  "stop_token_ids": [],
  "include_stop_str_in_output": false,
  "ignore_eos": false,
  "max_tokens": 200,
  "logprobs": null,
  "prompt_logprobs": null,
  "skip_special_tokens": true,
  "spaces_between_special_tokens": true,
  "schema": {
    "properties": {
      "kind": {
        "title": "Kind",
        "type": "string"
      },
      "color": {
        "title": "Color",
        "type": "string"
      },
      "count": {
        "title": "Count",
        "type": "integer"
      },
      "weight": {
        "title": "Weight",
        "type": "number"
      },
      "sweet": {
        "title": "Sweet",
        "type": "boolean"
      }
    },
    "required": [
      "kind",
      "color",
      "count",
      "weight",
      "sweet"
    ],
    "title": "Fruit",
    "type": "object"
  }
}'
```

output:

```
{"text":["You are an AI programming assistant, utilizing the Deepseek Coder model, developed by Deepseek Company, and you only answer questions related to computer science. For politically sensitive questions, security and privacy issues, and other non-computer science questions, you will refuse to answer.\n\nYou are a helpful AI assistant. You give concise answers. If you do not know something, then say so.\n### Instruction:\nWrite a JSON describing a random fruit. It must conform to the following JSON schema: {\"properties\": {\"kind\": {\"title\": \"Kind\", \"type\": \"string\"}, \"color\": {\"title\": \"Color\", \"type\": \"string\"}, \"count\": {\"title\": \"Count\", \"type\": \"integer\"}, \"weight\": {\"title\": \"Weight\", \"type\": \"number\"}, \"sweet\": {\"title\": \"Sweet\", \"type\": \"boolean\"}}, \"required\": [\"kind\", \"color\", \"count\", \"weight\", \"sweet\"], \"title\": \"Fruit\", \"type\": \"object\"}\n\n### Response:\n{\n\"kind\": \"Apple\",\n\"color\": \"Red\",\n\"count\": 10,\n\"weight\": 0.2,\n\"sweet\": true\n}","You are an AI programming assistant, utilizing the Deepseek Coder model, developed by Deepseek Company, and you only answer questions related to computer science. For politically sensitive questions, security and privacy issues, and other non-computer science questions, you will refuse to answer.\n\nYou are a helpful AI assistant. You give concise answers. If you do not know something, then say so.\n### Instruction:\nWrite a JSON describing a random fruit. It must conform to the following JSON schema: {\"properties\": {\"kind\": {\"title\": \"Kind\", \"type\": \"string\"}, \"color\": {\"title\": \"Color\", \"type\": \"string\"}, \"count\": {\"title\": \"Count\", \"type\": \"integer\"}, \"weight\": {\"title\": \"Weight\", \"type\": \"number\"}, \"sweet\": {\"title\": \"Sweet\", \"type\": \"boolean\"}}, \"required\": [\"kind\", \"color\", \"count\", \"weight\", \"sweet\"], \"title\": \"Fruit\", \"type\": \"object\"}\n\n### Response:\n{\n    \"kind\": \"Apple\",\n    \"color\": \"Red\",\n    \"count\": 10,\n    \"weight\": 0.2,\n    \"sweet\": true\n}","You are an AI programming assistant, utilizing the Deepseek Coder model, developed by Deepseek Company, and you only answer questions related to computer science. For politically sensitive questions, security and privacy issues, and other non-computer science questions, you will refuse to answer.\n\nYou are a helpful AI assistant. You give concise answers. If you do not know something, then say so.\n### Instruction:\nWrite a JSON describing a random fruit. It must conform to the following JSON schema: {\"properties\": {\"kind\": {\"title\": \"Kind\", \"type\": \"string\"}, \"color\": {\"title\": \"Color\", \"type\": \"string\"}, \"count\": {\"title\": \"Count\", \"type\": \"integer\"}, \"weight\": {\"title\": \"Weight\", \"type\": \"number\"}, \"sweet\": {\"title\": \"Sweet\", \"type\": \"boolean\"}}, \"required\": [\"kind\", \"color\", \"count\", \"weight\", \"sweet\"], \"title\": \"Fruit\", \"type\": \"object\"}\n\n### Response:\n{\n  \"kind\": \"apple\",\n  \"color\": \"red\",\n  \"count\": 5,\n  \"weight\": 0.1,\n  \"sweet\": true\n}","You are an AI programming assistant, utilizing the Deepseek Coder model, developed by Deepseek Company, and you only answer questions related to computer science. For politically sensitive questions, security and privacy issues, and other non-computer science questions, you will refuse to answer.\n\nYou are a helpful AI assistant. You give concise answers. 
If you do not know something, then say so.\n### Instruction:\nWrite a JSON describing a random fruit. It must conform to the following JSON schema: {\"properties\": {\"kind\": {\"title\": \"Kind\", \"type\": \"string\"}, \"color\": {\"title\": \"Color\", \"type\": \"string\"}, \"count\": {\"title\": \"Count\", \"type\": \"integer\"}, \"weight\": {\"title\": \"Weight\", \"type\": \"number\"}, \"sweet\": {\"title\": \"Sweet\", \"type\": \"boolean\"}}, \"required\": [\"kind\", \"color\", \"count\", \"weight\", \"sweet\"], \"title\": \"Fruit\", \"type\": \"object\"}\n\n### Response:\n{\n\"kind\": \"Apple\",\n\"color\": \"Red\",\n\"count\": 10,\n\"weight\": 0.24,\n\"sweet\": true\n}","You are an AI programming assistant, utilizing the Deepseek Coder model, developed by Deepseek Company, and you only answer questions related to computer science. For politically sensitive questions, security and privacy issues, and other non-computer science questions, you will refuse to answer.\n\nYou are a helpful AI assistant. You give concise answers. If you do not know something, then say so.\n### Instruction:\nWrite a JSON describing a random fruit. It must conform to the following JSON schema: {\"properties\": {\"kind\": {\"title\": \"Kind\", \"type\": \"string\"}, \"color\": {\"title\": \"Color\", \"type\": \"string\"}, \"count\": {\"title\": \"Count\", \"type\": \"integer\"}, \"weight\": {\"title\": \"Weight\", \"type\": \"number\"}, \"sweet\": {\"title\": \"Sweet\", \"type\": \"boolean\"}}, \"required\": [\"kind\", \"color\", \"count\", \"weight\", \"sweet\"], \"title\": \"Fruit\", \"type\": \"object\"}\n\n### Response:\n{\n  \"kind\": \"Apple\",\n  \"color\": \"red\",\n  \"count\": 5,\n  \"weight\": 0.3,\n  \"sweet\": true\n}"]}
```

</details>

---------

Co-authored-by: Andrew Lapp <[email protected]>
rlouf closed this as completed on Feb 14, 2024