Skip to content

Commit

Permalink
add test for partial
Browse files Browse the repository at this point in the history
  • Loading branch information
technillogue committed Aug 21, 2024
1 parent 8c50fea commit f594145
Show file tree
Hide file tree
Showing 3 changed files with 56 additions and 0 deletions.
Original file line number Diff line number Diff line change
@@ -0,0 +1,4 @@
build:
python_version: "3.8"
predict: "predict.py:Predictor"
train: "predict.py:train"
Original file line number Diff line number Diff line change
@@ -0,0 +1,38 @@
from typing import Callable
import functools

from cog import BasePredictor, Input


class Predictor(BasePredictor):
def general(
self, prompt: str = Input(description="hi"), system_prompt: str = None
) -> int:
return 1

def _remove(f: Callable, defaults: dict[str, Any]) -> Callable:
# pylint: disable=no-self-argument
def wrapper(self, *args, **kwargs):
kwargs.update(defaults)
return f(self, *args, **kwargs)

# Update wrapper attributes for documentation, etc.
functools.update_wrapper(wrapper, f)

# for the purposes of inspect.signature as used by predictor.get_input_type,
# remove the argument (system_prompt)
sig = inspect.signature(f)
params = [p for name, p in sig.parameters.items() if name not in defaults]
wrapper.__signature__ = sig.replace(parameters=params)

# Return partialmethod, wrapper behaves correctly when part of a class
return functools.partialmethod(wrapper)

predict = _remove(general, {"system_prompt": ""})


def _train(self, prompt: str = Input(description="hi"), system_prompt: str = None):
return 1


train = functools.partial(_train, system_prompt="")
14 changes: 14 additions & 0 deletions test-integration/test_integration/test_predict.py
Original file line number Diff line number Diff line change
Expand Up @@ -299,6 +299,20 @@ def test_predict_works_with_deferred_annotations():
timeout=DEFAULT_TIMEOUT,
)

def test_predict_works_with_partial_wrapper():
project_dir = Path(__file__).parent / "fixtures/partial-predict-project"

subprocess.check_call(
["cog", "predict", "-i", "prompt=world"],
cwd=project_dir,
timeout=DEFAULT_TIMEOUT,
)
subprocess.check_call(
["cog", "train", "-i", "prompt=world"],
cwd=project_dir,
timeout=DEFAULT_TIMEOUT,
)


def test_predict_int_none_output():
project_dir = Path(__file__).parent / "fixtures/int-none-output-project"
Expand Down

0 comments on commit f594145

Please sign in to comment.