Skip to content

Commit

Permalink
Merge branch 'main' into patch-1
Browse files Browse the repository at this point in the history
  • Loading branch information
Muhtasham authored Sep 24, 2024
2 parents 1908514 + 538e714 commit 5df3c10
Show file tree
Hide file tree
Showing 4 changed files with 28 additions and 9 deletions.
15 changes: 13 additions & 2 deletions .github/workflows/asv_benchmark_pr.yml
Original file line number Diff line number Diff line change
@@ -1,18 +1,29 @@
name: Benchmark PR

on:
push:
pull_request:
branches: [main]
types: [synchronize, labeled]
workflow_dispatch:
env:
PYTHON_VERSION: "3.10"
WORKING_DIR: ${{ github.workspace }}/benchmarks
BENCHMARKS_OUTPUT: ${{ github.workspace }}/benchmarks_output

permissions:
contents: read

# Cancels all previous workflow runs for pull requests that have not completed.
concurrency:
# The concurrency group contains the workflow name and the branch name for pull requests
# or the commit hash for any other events.
group: ${{ github.workflow }}-${{ github.event_name == 'pull_request' && github.head_ref || github.sha }}
cancel-in-progress: true

jobs:
benchmark-pr:
runs-on: ubuntu-latest
if: contains(github.event.pull_request.labels.*.name, 'run_benchmarks') || github.event_name == 'workflow_dispatch' || github.event_name == 'workflow_run'
if: ${{ contains(github.event.pull_request.labels.*.name, 'run-benchmarks') || github.ref == 'refs/heads/main' }}

defaults:
run:
Expand Down
3 changes: 0 additions & 3 deletions README.md
Original file line number Diff line number Diff line change
Expand Up @@ -3,7 +3,6 @@
<img src="./docs/assets/images/logo.png" alt="Outlines Logo" width=500></img>

[![.txt Twitter][dottxt-twitter-badge]][dottxt-twitter]
[![Outlines Twitter][outlines-twitter-badge]][outlines-twitter]

[![Documentation][documentation-badge]][documentation]
[![Contributors][contributors-badge]][contributors]
Expand Down Expand Up @@ -359,10 +358,8 @@ answer = outlines.generate.text(model)(prompt, max_tokens=100)
[contributors]: https://github.com/dottxt-ai/outlines/graphs/contributors
[contributors-badge]: https://img.shields.io/github/contributors/dottxt-ai/outlines?style=flat-square&logo=github&logoColor=white&color=ECEFF4
[dottxt-twitter]: https://twitter.com/dottxtai
[outlines-twitter]: https://twitter.com/OutlinesOSS
[discord]: https://discord.gg/R9DSu34mGd
[discord-badge]: https://img.shields.io/discord/1182316225284554793?color=81A1C1&logo=discord&logoColor=white&style=flat-square
[downloads-badge]: https://img.shields.io/pypi/dm/outlines?color=89AC6B&logo=python&logoColor=white&style=flat-square
[pypistats]: https://pypistats.org/packages/outlines
[dottxt-twitter-badge]: https://img.shields.io/twitter/follow/dottxtai?style=social
[outlines-twitter-badge]: https://img.shields.io/twitter/follow/OutlinesOSS?style=social
9 changes: 5 additions & 4 deletions outlines/generate/api.py
Original file line number Diff line number Diff line change
@@ -1,4 +1,5 @@
import datetime
from copy import copy
from dataclasses import dataclass
from typing import TYPE_CHECKING, Any, Iterator, List, Optional, Union

Expand Down Expand Up @@ -503,7 +504,7 @@ def __call__(
completions = self.model.generate(
prompts,
generation_params,
self.logits_processor,
copy(self.logits_processor),
self.sampling_params,
**model_specific_params,
)
Expand All @@ -525,7 +526,7 @@ def stream(
return self.model.stream(
prompts,
generation_params,
self.logits_processor,
copy(self.logits_processor),
self.sampling_params,
**model_specific_params,
)
Expand Down Expand Up @@ -556,7 +557,7 @@ def __call__( # type: ignore
prompts,
media,
generation_params,
self.logits_processor,
copy(self.logits_processor),
self.sampling_params,
**model_specific_params,
)
Expand All @@ -581,7 +582,7 @@ def stream( # type: ignore
prompts,
media,
generation_params,
self.logits_processor,
copy(self.logits_processor),
self.sampling_params,
**model_specific_params,
)
Expand Down
10 changes: 10 additions & 0 deletions tests/generate/test_generate.py
Original file line number Diff line number Diff line change
Expand Up @@ -253,6 +253,16 @@ def test_generate_choice(request, model_fixture, sample_choices):
assert res in sample_choices


@pytest.mark.parametrize("model_fixture", ALL_MODEL_FIXTURES)
def test_generate_choice_twice(request, model_fixture, sample_choices):
model = request.getfixturevalue(model_fixture)
generator = generate.choice(model, sample_choices)
res = generator(**get_inputs(model_fixture))
assert res in sample_choices
res = generator(**get_inputs(model_fixture))
assert res in sample_choices


@pytest.mark.parametrize("model_fixture", ALL_MODEL_FIXTURES)
def test_generate_format_bool(request, model_fixture):
model = request.getfixturevalue(model_fixture)
Expand Down

0 comments on commit 5df3c10

Please sign in to comment.