Keep track of state in RegexLogitsProcessor using input_ids (#628)
For `outlines/vllm`, FSM-sequence correspondence was previously broken,
resulting in FSM state being mixed between sequences and corrupting output.
To alleviate this, we had `_patched_apply_logits_processors`, which passed a
stable sequence ID to the logits processor.

In this PR we eliminate `_patched_apply_logits_processors` and cache FSM
state keyed on each sequence's input IDs.
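
As a rough illustration of the caching idea (not the code from this PR), the sketch below keys FSM state by `hash(tuple(input_ids))`. `ToyFSM` and `StateCache` are hypothetical stand-ins for `RegexFSM` and the logits processor. It assumes the FSM is deterministic, so the state depends only on the tokens generated so far, and unlike the PR it never resets the cache.

```python
from collections import defaultdict
from typing import DefaultDict, List


class ToyFSM:
    """Hypothetical stand-in for RegexFSM: accepts 0, 1, 2, 0, 1, 2, ..."""

    def next_state(self, state: int, token_id: int) -> int:
        # The only legal token in state s is s itself; advance cyclically.
        assert token_id == state, "token not allowed in this state"
        return (state + 1) % 3

    def allowed_token_ids(self, state: int) -> List[int]:
        return [state]


class StateCache:
    """Tracks FSM state per sequence, keyed by a hash of its token IDs."""

    def __init__(self, fsm: ToyFSM) -> None:
        self.fsm = fsm
        # Missing keys default to 0, the initial state (empty sequence).
        self.fsm_state: DefaultDict[int, int] = defaultdict(int)

    def allowed(self, input_ids: List[int]) -> List[int]:
        seq_id = hash(tuple(input_ids))
        if input_ids:
            # Look up the state reached by the prefix, then advance it
            # by the most recent token.
            last_seq_id = hash(tuple(input_ids[:-1]))
            self.fsm_state[seq_id] = self.fsm.next_state(
                self.fsm_state[last_seq_id], input_ids[-1]
            )
        return self.fsm.allowed_token_ids(self.fsm_state[seq_id])


cache = StateCache(ToyFSM())
# Two sequences advance at different rates, interleaved as in a batch.
seq_a: List[int] = []
seq_b: List[int] = []
for step in range(4):
    seq_a.append(cache.allowed(seq_a)[0])
    if step % 2 == 0:
        seq_b.append(cache.allowed(seq_b)[0])
```

Because the key is derived from the tokens themselves, two sequences that happen to share a prefix also share a cache entry, which is harmless under the determinism assumption.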

This is a continuation of #539, but much simpler because the vllm upgrade
fixed many of the issues being addressed there.

Related discussions:
- #624

Fixes:
- Fixes #605 
- Fixes #610

Already fixed:
- #524 (this one can be
closed, as it was addressed previously by upgrading vllm)


@viktor-ferenczi can you please confirm whether this branch fixes either
#610 or #605?

# Smoke tests

### basic parallel

passed

<details>

```
import json
import vllm
from pydantic import BaseModel
from typing import List
import torch
import pandas as pd
from outlines.serve.vllm import JSONLogitsProcessor

class ConceptsList(BaseModel):
    concepts: List[str]

BASE_MODEL = "microsoft/phi-2"
llm = vllm.LLM(model=BASE_MODEL, tensor_parallel_size=1, dtype=torch.float16, max_model_len=2048)

logits_processor = JSONLogitsProcessor(ConceptsList, llm.llm_engine)

full_prompts = [
    f"Provide me a list of {i} strings with key 'concepts'"
    for i in range(20)
]

batch_results = llm.generate(
    full_prompts,
    sampling_params=vllm.SamplingParams(
        max_tokens=2048, logits_processors=[logits_processor]
    ),
)


# Every completion must parse as JSON; json.loads raises otherwise.
for result in batch_results:
    for output in result.outputs:
        json.loads(output.text)
```

</details>


### never ending regex

passed

<details>

`python3 -m outlines.serve.serve --model="microsoft/phi-2"`

```
curl http://127.0.0.1:8000/generate \
    -d '{
        "prompt": "Sequence of numbers and letters:",
        "regex": "([123]-[abc]-([def]-)?)*",
        "n": 7
}'
{"text":["Sequence of numbers and letters:1-a-1-b-1-c-1-a-","Sequence of numbers and letters:1-a-2-b-3-c-1-a-","Sequence of numbers and letters:1-a-2-b-3-c-d-1-","Sequence of numbers and letters:2-a-1-b-2-c-1-b-","Sequence of numbers and letters:2-b-3-c-d-2-b-3-","Sequence of numbers and letters:2-a-3-b-2-b-1-c-","Sequence of numbers and letters:2-a-3-b-d-2-a-3-"]}


# rules for the above to validate correct FSM-sequence correspondence:
# [123] always followed by [abc], [def] only ever preceded by [abc]

# 1-a-1-b-1-c-1-a-
# 1-a-2-b-3-c-1-a-
# 1-a-2-b-3-c-d-1-
# 2-a-1-b-2-c-1-b-
# 2-b-3-c-d-2-b-3-
# 2-a-3-b-2-b-1-c-
# 2-a-3-b-d-2-a-3-
```
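
The two rules above can also be checked mechanically. The following is a minimal sketch, assuming the echoed prompt has already been stripped from each completion; `follows_rules` is a hypothetical helper, not part of outlines. It validates prefixes rather than full matches, since generation here can stop mid-pattern.

```python
from typing import Optional

DIGITS, LETTERS, OPTIONAL = set("123"), set("abc"), set("def")


def follows_rules(completion: str) -> bool:
    """Return True if `completion` could be a prefix of a string matching
    ([123]-[abc]-([def]-)?)*, judged by its dash-separated symbols."""
    symbols = completion.split("-")
    # A trailing "-" leaves an empty final element; drop it.
    if symbols[-1] == "":
        symbols.pop()
    previous: Optional[str] = None  # None marks the start of the sequence
    for symbol in symbols:
        if symbol in DIGITS:
            # A digit opens the sequence or follows a completed group.
            ok = previous is None or previous in LETTERS or previous in OPTIONAL
        elif symbol in LETTERS:
            ok = previous in DIGITS
        elif symbol in OPTIONAL:
            ok = previous in LETTERS
        else:
            ok = False
        if not ok:
            return False
        previous = symbol
    return True


completions = [
    "1-a-1-b-1-c-1-a-",
    "1-a-2-b-3-c-d-1-",
    "2-a-3-b-d-2-a-3-",
]
checks = [follows_rules(c) for c in completions]
```

A completion such as `1-d-`, which could only arise if one sequence picked up another's FSM state, is rejected by this check.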

</details>


### sometimes ending early regex

passed

<details>

`python3 -m outlines.serve.serve --model="microsoft/phi-2"`

```
curl http://127.0.0.1:8000/generate \
    -d '{
        "prompt": "Sequence of numbers and letters:",
        "regex": "([123]-[abc]-([def]-)?){3}",
        "n": 16
}'
```

output

```
{"text":["Sequence of numbers and letters:1-a-2-b-3-c-d-","Sequence of numbers and letters:1-a-2-b-3-c-d-","Sequence of numbers and letters:1-a-2-b-3-c-d-","Sequence of numbers and letters:1-a-2-b-3-c-d-","Sequence of numbers and letters:1-a-2-b-3-c-d-","Sequence of numbers and letters:3-a-1-b-2-c-d-","Sequence of numbers and letters:2-a-1-b-3-c-d-","Sequence of numbers and letters:1-a-1-b-1-c-d-","Sequence of numbers and letters:2-a-3-b-d-1-c-e-","Sequence of numbers and letters:1-b-3-a-2-c-d-","Sequence of numbers and letters:3-a-d-1-b-e-2-c-","Sequence of numbers and letters:1-a-3-b-1-b-d-","Sequence of numbers and letters:3-a-f-2-b-d-1-c-","Sequence of numbers and letters:1-b-d-3-a-e-2-c-","Sequence of numbers and letters:3-c-1-b-d-1-a-e-","Sequence of numbers and letters:1-c-1-c-e-1-b-e-"]}
```

analysis:

```
1-a-2-b-3-c-d-
1-a-2-b-3-c-d-
1-a-2-b-3-c-d-
1-a-2-b-3-c-d-
1-a-2-b-3-c-d-
3-a-1-b-2-c-d-
2-a-1-b-3-c-d-
1-a-1-b-1-c-d-
2-a-3-b-d-1-c-e-
1-b-3-a-2-c-d-
3-a-d-1-b-e-2-c-
1-a-3-b-1-b-d-
3-a-f-2-b-d-1-c-
1-b-d-3-a-e-2-c-
3-c-1-b-d-1-a-e-
1-c-1-c-e-1-b-e-
```

Observations:
- All patterns are correct
- Patterns don't "borrow" FSM state from one another; they retain their
own independent state
- Some sequences successfully produced more tokens than others
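
These observations can be confirmed with a full-match check, since this regex is bounded. A minimal sketch using only the standard library; `completion_matches` is a hypothetical helper, and the three samples are copied from the output above.

```python
import re

PROMPT = "Sequence of numbers and letters:"
PATTERN = re.compile(r"([123]-[abc]-([def]-)?){3}")


def completion_matches(text: str) -> bool:
    """Strip the echoed prompt and require the rest to match the regex exactly."""
    if not text.startswith(PROMPT):
        return False
    return PATTERN.fullmatch(text[len(PROMPT):]) is not None


samples = [
    "Sequence of numbers and letters:1-a-2-b-3-c-d-",
    "Sequence of numbers and letters:3-a-d-1-b-e-2-c-",
    "Sequence of numbers and letters:1-c-1-c-e-1-b-e-",
]
results = [completion_matches(s) for s in samples]
```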


</details>

### Viktor's regex

passed

<details>

`python3 -m outlines.serve.serve --model="microsoft/phi-2"`

```
curl http://127.0.0.1:8000/generate \
    -d '{
  "prompt": "You are an AI programming assistant, utilizing the Deepseek Coder model, developed by Deepseek Company, and you only answer questions related to computer science. For politically sensitive questions, security and privacy issues, and other non-computer science questions, you will refuse to answer.\n\nYou are a helpful AI assistant. You give concise answers. If you do not know something, then say so.\n### Instruction:\nWrite down the first 10 prime numbers as a comma separated list, starting with 2.\n\n### Response:\n",
  "n": 1,
  "best_of": 1,
  "presence_penalty": 0.0,
  "frequency_penalty": 0.0,
  "repetition_penalty": 1.0,
  "temperature": 0.0,
  "top_p": 1.0,
  "top_k": -1,
  "min_p": 0.0,
  "use_beam_search": false,
  "length_penalty": 1.0,
  "early_stopping": false,
  "stop": [],
  "stop_token_ids": [],
  "include_stop_str_in_output": false,
  "ignore_eos": false,
  "max_tokens": 50,
  "logprobs": null,
  "prompt_logprobs": null,
  "skip_special_tokens": true,
  "spaces_between_special_tokens": true,
  "regex": "\\d+(\\s*,\\s*\\d+)*\\s*"
}'
```

output:

```
{"text":["You are an AI programming assistant, utilizing the Deepseek Coder model, developed by Deepseek Company, and you only answer questions related to computer science. For politically sensitive questions, security and privacy issues, and other non-computer science questions, you will refuse to answer.\n\nYou are a helpful AI assistant. You give concise answers. If you do not know something, then say so.\n### Instruction:\nWrite down the first 10 prime numbers as a comma separated list, starting with 2.\n\n### Response:\n2, 3, 5, 7, 11, 13, 17, 19, 23, 29\n"]}
```

</details>

### Viktor's schema

passed

<details>

`python3 -m outlines.serve.serve --model="microsoft/phi-2"`

```
curl http://127.0.0.1:8000/generate \
    -d '{
  "prompt": "You are an AI programming assistant, utilizing the Deepseek Coder model, developed by Deepseek Company, and you only answer questions related to computer science. For politically sensitive questions, security and privacy issues, and other non-computer science questions, you will refuse to answer.\n\nYou are a helpful AI assistant. You give concise answers. If you do not know something, then say so.\n### Instruction:\nWrite a JSON describing a random fruit. It must conform to the following JSON schema: {\"properties\": {\"kind\": {\"title\": \"Kind\", \"type\": \"string\"}, \"color\": {\"title\": \"Color\", \"type\": \"string\"}, \"count\": {\"title\": \"Count\", \"type\": \"integer\"}, \"weight\": {\"title\": \"Weight\", \"type\": \"number\"}, \"sweet\": {\"title\": \"Sweet\", \"type\": \"boolean\"}}, \"required\": [\"kind\", \"color\", \"count\", \"weight\", \"sweet\"], \"title\": \"Fruit\", \"type\": \"object\"}\n\n### Response:\n",
  "n": 5,
  "best_of": 5,
  "presence_penalty": 0.0,
  "frequency_penalty": 0.0,
  "repetition_penalty": 1.0,
  "temperature": 1.0,
  "top_p": 1.0,
  "top_k": -1,
  "min_p": 0.0,
  "use_beam_search": false,
  "length_penalty": 1.0,
  "early_stopping": false,
  "stop": [],
  "stop_token_ids": [],
  "include_stop_str_in_output": false,
  "ignore_eos": false,
  "max_tokens": 200,
  "logprobs": null,
  "prompt_logprobs": null,
  "skip_special_tokens": true,
  "spaces_between_special_tokens": true,
  "schema": {
    "properties": {
      "kind": {
        "title": "Kind",
        "type": "string"
      },
      "color": {
        "title": "Color",
        "type": "string"
      },
      "count": {
        "title": "Count",
        "type": "integer"
      },
      "weight": {
        "title": "Weight",
        "type": "number"
      },
      "sweet": {
        "title": "Sweet",
        "type": "boolean"
      }
    },
    "required": [
      "kind",
      "color",
      "count",
      "weight",
      "sweet"
    ],
    "title": "Fruit",
    "type": "object"
  }
}'
```

output:

```
{"text":["You are an AI programming assistant, utilizing the Deepseek Coder model, developed by Deepseek Company, and you only answer questions related to computer science. For politically sensitive questions, security and privacy issues, and other non-computer science questions, you will refuse to answer.\n\nYou are a helpful AI assistant. You give concise answers. If you do not know something, then say so.\n### Instruction:\nWrite a JSON describing a random fruit. It must conform to the following JSON schema: {\"properties\": {\"kind\": {\"title\": \"Kind\", \"type\": \"string\"}, \"color\": {\"title\": \"Color\", \"type\": \"string\"}, \"count\": {\"title\": \"Count\", \"type\": \"integer\"}, \"weight\": {\"title\": \"Weight\", \"type\": \"number\"}, \"sweet\": {\"title\": \"Sweet\", \"type\": \"boolean\"}}, \"required\": [\"kind\", \"color\", \"count\", \"weight\", \"sweet\"], \"title\": \"Fruit\", \"type\": \"object\"}\n\n### Response:\n{\n\"kind\": \"Apple\",\n\"color\": \"Red\",\n\"count\": 10,\n\"weight\": 0.2,\n\"sweet\": true\n}","You are an AI programming assistant, utilizing the Deepseek Coder model, developed by Deepseek Company, and you only answer questions related to computer science. For politically sensitive questions, security and privacy issues, and other non-computer science questions, you will refuse to answer.\n\nYou are a helpful AI assistant. You give concise answers. If you do not know something, then say so.\n### Instruction:\nWrite a JSON describing a random fruit. 
It must conform to the following JSON schema: {\"properties\": {\"kind\": {\"title\": \"Kind\", \"type\": \"string\"}, \"color\": {\"title\": \"Color\", \"type\": \"string\"}, \"count\": {\"title\": \"Count\", \"type\": \"integer\"}, \"weight\": {\"title\": \"Weight\", \"type\": \"number\"}, \"sweet\": {\"title\": \"Sweet\", \"type\": \"boolean\"}}, \"required\": [\"kind\", \"color\", \"count\", \"weight\", \"sweet\"], \"title\": \"Fruit\", \"type\": \"object\"}\n\n### Response:\n{\n    \"kind\": \"Apple\",\n    \"color\": \"Red\",\n    \"count\": 10,\n    \"weight\": 0.2,\n    \"sweet\": true\n}","You are an AI programming assistant, utilizing the Deepseek Coder model, developed by Deepseek Company, and you only answer questions related to computer science. For politically sensitive questions, security and privacy issues, and other non-computer science questions, you will refuse to answer.\n\nYou are a helpful AI assistant. You give concise answers. If you do not know something, then say so.\n### Instruction:\nWrite a JSON describing a random fruit. It must conform to the following JSON schema: {\"properties\": {\"kind\": {\"title\": \"Kind\", \"type\": \"string\"}, \"color\": {\"title\": \"Color\", \"type\": \"string\"}, \"count\": {\"title\": \"Count\", \"type\": \"integer\"}, \"weight\": {\"title\": \"Weight\", \"type\": \"number\"}, \"sweet\": {\"title\": \"Sweet\", \"type\": \"boolean\"}}, \"required\": [\"kind\", \"color\", \"count\", \"weight\", \"sweet\"], \"title\": \"Fruit\", \"type\": \"object\"}\n\n### Response:\n{\n  \"kind\": \"apple\",\n  \"color\": \"red\",\n  \"count\": 5,\n  \"weight\": 0.1,\n  \"sweet\": true\n}","You are an AI programming assistant, utilizing the Deepseek Coder model, developed by Deepseek Company, and you only answer questions related to computer science. For politically sensitive questions, security and privacy issues, and other non-computer science questions, you will refuse to answer.\n\nYou are a helpful AI assistant. 
You give concise answers. If you do not know something, then say so.\n### Instruction:\nWrite a JSON describing a random fruit. It must conform to the following JSON schema: {\"properties\": {\"kind\": {\"title\": \"Kind\", \"type\": \"string\"}, \"color\": {\"title\": \"Color\", \"type\": \"string\"}, \"count\": {\"title\": \"Count\", \"type\": \"integer\"}, \"weight\": {\"title\": \"Weight\", \"type\": \"number\"}, \"sweet\": {\"title\": \"Sweet\", \"type\": \"boolean\"}}, \"required\": [\"kind\", \"color\", \"count\", \"weight\", \"sweet\"], \"title\": \"Fruit\", \"type\": \"object\"}\n\n### Response:\n{\n\"kind\": \"Apple\",\n\"color\": \"Red\",\n\"count\": 10,\n\"weight\": 0.24,\n\"sweet\": true\n}","You are an AI programming assistant, utilizing the Deepseek Coder model, developed by Deepseek Company, and you only answer questions related to computer science. For politically sensitive questions, security and privacy issues, and other non-computer science questions, you will refuse to answer.\n\nYou are a helpful AI assistant. You give concise answers. If you do not know something, then say so.\n### Instruction:\nWrite a JSON describing a random fruit. It must conform to the following JSON schema: {\"properties\": {\"kind\": {\"title\": \"Kind\", \"type\": \"string\"}, \"color\": {\"title\": \"Color\", \"type\": \"string\"}, \"count\": {\"title\": \"Count\", \"type\": \"integer\"}, \"weight\": {\"title\": \"Weight\", \"type\": \"number\"}, \"sweet\": {\"title\": \"Sweet\", \"type\": \"boolean\"}}, \"required\": [\"kind\", \"color\", \"count\", \"weight\", \"sweet\"], \"title\": \"Fruit\", \"type\": \"object\"}\n\n### Response:\n{\n  \"kind\": \"Apple\",\n  \"color\": \"red\",\n  \"count\": 5,\n  \"weight\": 0.3,\n  \"sweet\": true\n}"]}
```
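
One way to check such completions programmatically is sketched below with the standard library only. It assumes the echoed prompt has been stripped, and it is a simplified stand-in for real JSON Schema validation: `is_valid_fruit` and `REQUIRED_FIELDS` are hypothetical names, and only the required keys and their basic types are checked.

```python
import json
from typing import Any, Dict

# Expected Python type for each required field of the Fruit schema.
REQUIRED_FIELDS: Dict[str, Any] = {
    "kind": str,
    "color": str,
    "count": int,
    "weight": (int, float),
    "sweet": bool,
}


def is_valid_fruit(completion: str) -> bool:
    """Return True if `completion` is a JSON object with the required fields."""
    try:
        data = json.loads(completion)
    except json.JSONDecodeError:
        return False
    if not isinstance(data, dict):
        return False
    for name, expected in REQUIRED_FIELDS.items():
        if name not in data:
            return False
        value = data[name]
        # bool is a subclass of int in Python, so reject it for numeric fields.
        if expected is not bool and isinstance(value, bool):
            return False
        if not isinstance(value, expected):
            return False
    return True


sample = '{\n"kind": "Apple",\n"color": "Red",\n"count": 10,\n"weight": 0.2,\n"sweet": true\n}'
sample_ok = is_valid_fruit(sample)
```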

</details>

---------

Co-authored-by: Andrew Lapp <[email protected]>
lapp0 and Andrew Lapp authored Feb 14, 2024
1 parent a33692e commit b89e8df
Showing 2 changed files with 6 additions and 43 deletions.
11 changes: 1 addition & 10 deletions outlines/serve/serve.py
@@ -16,23 +16,14 @@
 from typing import AsyncGenerator

 import uvicorn
-import vllm.model_executor.layers.sampler as sampler
 from fastapi import FastAPI, Request
 from fastapi.responses import JSONResponse, Response, StreamingResponse
 from vllm.engine.arg_utils import AsyncEngineArgs
 from vllm.engine.async_llm_engine import AsyncLLMEngine
 from vllm.sampling_params import SamplingParams
 from vllm.utils import random_uuid

-from .vllm import (
-    JSONLogitsProcessor,
-    RegexLogitsProcessor,
-    _patched_apply_logits_processors,
-)
-
-# Patch the _apply_logits_processors so it is compatible with `JSONLogitsProcessor`
-sampler._apply_logits_processors = _patched_apply_logits_processors
-
+from .vllm import JSONLogitsProcessor, RegexLogitsProcessor

 TIMEOUT_KEEP_ALIVE = 5 # seconds.
 TIMEOUT_TO_PREVENT_DEADLOCK = 1 # seconds.
38 changes: 5 additions & 33 deletions outlines/serve/vllm.py
@@ -10,35 +10,6 @@
 from outlines.fsm.json_schema import build_regex_from_object


-def _patched_apply_logits_processors(
-    logits,
-    sampling_metadata,
-):
-    """Patch vLLM's logit processor.
-
-    We need to patch the logits processor to pass the `seq_id` so we can
-    handle several sequences in `JSONLogitsProcessor`
-    """
-    logits_row_idx = 0
-    found_logits_processors = False
-    for seq_ids, sampling_params in sampling_metadata.seq_groups:
-        logits_processors = sampling_params.logits_processors
-        if logits_processors:
-            found_logits_processors = True
-            for seq_id in seq_ids:
-                logits_row = logits[logits_row_idx]
-                token_ids = sampling_metadata.seq_data[seq_id].output_token_ids
-                for logits_processor in logits_processors:
-                    logits_row = logits_processor(seq_id, token_ids, logits_row)
-                logits[logits_row_idx] = logits_row
-                logits_row_idx += 1
-        else:
-            logits_row_idx += len(seq_ids)
-    if found_logits_processors:
-        assert logits_row_idx == logits.shape[0]
-    return logits
-
-
 class RegexLogitsProcessor:
     def __init__(self, regex_string, llm):
         """Compile the FSM that drives the regex-structured generation.
@@ -56,17 +27,18 @@ def __init__(self, regex_string, llm):
         fsm = RegexFSM(regex_string, tokenizer)
         self.fsm = fsm

-    def __call__(
-        self, seq_id: int, input_ids: List[int], scores: torch.Tensor
-    ) -> torch.Tensor:
+    def __call__(self, input_ids: List[int], scores: torch.Tensor) -> torch.Tensor:
         """Use the FSM to bias the logits before sampling the next token."""

+        seq_id = hash(tuple(input_ids))
+
         if len(input_ids) == 0: # Initialize the fsm states
             self.fsm_state: DefaultDict[int, int] = defaultdict(int)
         else:
             last_token = input_ids[-1]
+            last_seq_id = hash(tuple(input_ids[:-1]))
             self.fsm_state[seq_id] = self.fsm.next_state(
-                self.fsm_state[seq_id], last_token
+                self.fsm_state[last_seq_id], last_token
             )

         allowed_tokens = self.fsm.allowed_token_ids(self.fsm_state[seq_id])
