Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

fix partial wrappers and deferred annotations #1895

Open
wants to merge 4 commits into
base: main
Choose a base branch
from
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
54 changes: 38 additions & 16 deletions python/cog/predictor.py
Original file line number Diff line number Diff line change
Expand Up @@ -8,6 +8,7 @@
import uuid
from abc import ABC, abstractmethod
from collections.abc import Iterator
from inspect import Signature
from pathlib import Path
from typing import (
Any,
Expand All @@ -18,6 +19,7 @@
Type,
Union,
cast,
get_type_hints,
)

try:
Expand Down Expand Up @@ -266,7 +268,7 @@ def validate_input_type(
type: Type[Any], # pylint: disable=redefined-builtin
name: str,
) -> None:
if type is inspect.Signature.empty:
if type is Signature.empty:
raise TypeError(
f"No input type provided for parameter `{name}`. Supported input types are: {readable_types_list(ALLOWED_INPUT_TYPES)}, or a Union or List of those types."
)
Expand All @@ -282,18 +284,20 @@ def validate_input_type(
)


def get_input_create_model_kwargs(signature: inspect.Signature) -> Dict[str, Any]:
def get_input_create_model_kwargs(
signature: Signature, input_types: Dict[str, Any]
) -> Dict[str, Any]:
create_model_kwargs = {}

order = 0

for name, parameter in signature.parameters.items():
InputType = parameter.annotation
InputType = input_types.get(name, parameter.annotation)

validate_input_type(InputType, name)

# if no default is specified, create an empty, required input
if parameter.default is inspect.Signature.empty:
if parameter.default is Signature.empty:
default = Input()
else:
default = parameter.default
Expand Down Expand Up @@ -354,25 +358,39 @@ class Input(BaseModel):
predict = get_predict(predictor)
signature = inspect.signature(predict)

try:
input_types = get_type_hints(predict)
except TypeError:
input_types = {}
if "return" in input_types:
del input_types["return"]

return create_model(
"Input",
__config__=None,
__base__=BaseInput,
__module__=__name__,
__validators__=None,
**get_input_create_model_kwargs(signature),
**get_input_create_model_kwargs(signature, input_types),
) # type: ignore


def get_return_annotation(fn: Callable[..., Any]) -> "Type[Any]|Signature.empty":
    """
    Return the resolved return annotation of `fn`, or `Signature.empty` if it has none.

    `get_type_hints` is tried first so that string/deferred annotations
    (e.g. under `from __future__ import annotations`) are evaluated to real
    types. `get_type_hints` raises `TypeError` for objects such as
    `functools.partial`, so in that case the partial chain is unwrapped and
    the hints of the underlying callable are used instead. Only if that also
    fails do we fall back to the raw annotation reported by
    `inspect.signature`, which may still be an unevaluated string.
    """
    # local import: this helper is the only user of functools in this block
    import functools

    try:
        return get_type_hints(fn).get("return", Signature.empty)
    except TypeError:
        pass

    # Binding arguments with partial() never changes the return type, so the
    # innermost wrapped callable is the right place to resolve hints from.
    target = fn
    while isinstance(target, functools.partial):
        target = target.func

    if target is not fn:
        try:
            return get_type_hints(target).get("return", Signature.empty)
        except (TypeError, NameError):
            # Keep the previous best-effort behaviour rather than raising here.
            pass

    return inspect.signature(fn).return_annotation


def get_output_type(predictor: BasePredictor) -> Type[BaseModel]:
"""
Creates a Pydantic Output model from the return type annotation of a Predictor's predict() method.
"""

predict = get_predict(predictor)
signature = inspect.signature(predict)
OutputType: Type[BaseModel]
if signature.return_annotation is inspect.Signature.empty:
maybe_output_type = get_return_annotation(predict)

if maybe_output_type is Signature.empty:
raise TypeError(
"""You must set an output type. If your model can return multiple output types, you can explicitly set `Any` as the output type.

Expand All @@ -387,9 +405,8 @@ def predict(
...
"""
)
else:
OutputType = signature.return_annotation

# we need the indirection to narrow the type to the one that's declared later
OutputType = maybe_output_type # pylint: disable=invalid-name
# The type that goes in the response is a list of the yielded type
if get_origin(OutputType) is Iterator:
# Annotated allows us to attach Field annotations to the list, which we use to mark that this is an iterator
Expand Down Expand Up @@ -452,13 +469,20 @@ class TrainingInput(BaseModel):
train = get_train(predictor)
signature = inspect.signature(train)

try:
input_types = get_type_hints(train)
except TypeError:
input_types = {}
if "return" in input_types:
del input_types["return"]

return create_model(
"TrainingInput",
__config__=None,
__base__=BaseInput,
__module__=__name__,
__validators__=None,
**get_input_create_model_kwargs(signature),
**get_input_create_model_kwargs(signature, input_types),
) # type: ignore


Expand All @@ -468,9 +492,9 @@ def get_training_output_type(predictor: BasePredictor) -> Type[BaseModel]:
"""

train = get_train(predictor)
signature = inspect.signature(train)

if signature.return_annotation is inspect.Signature.empty:
TrainingOutputType = get_return_annotation(train) # pylint: disable=invalid-name
if TrainingOutputType is Signature.empty:
raise TypeError(
"""You must set an output type. If your model can return multiple output types, you can explicitly set `Any` as the output type.

Expand All @@ -485,8 +509,6 @@ def train(
...
"""
)
else:
TrainingOutputType = signature.return_annotation

name = (
TrainingOutputType.__name__ if hasattr(TrainingOutputType, "__name__") else ""
Expand Down
Original file line number Diff line number Diff line change
@@ -0,0 +1,8 @@
from __future__ import annotations

from cog import BasePredictor, Input


class Predictor(BasePredictor):
    """Minimal greeter fixture with a single string input and a string output."""

    def predict(self, input: str = Input(description="Who to greet")) -> str:
        """Return the text "hello " followed by `input`."""
        greeting = "hello " + input
        return greeting
Original file line number Diff line number Diff line change
@@ -0,0 +1,4 @@
# Cog fixture config: Python 3.8, with the predict and train entry points both taken from predict.py.
build:
python_version: "3.8"
predict: "predict.py:Predictor"
train: "predict.py:train"
Original file line number Diff line number Diff line change
@@ -0,0 +1,39 @@
import functools
import inspect
from typing import Any, Callable

from cog import BasePredictor, Input


class Predictor(BasePredictor):
    """Fixture predictor whose `predict` is a partialmethod wrapper around `general`."""

    def general(
        self, prompt: str = Input(description="hi"), system_prompt: str = None
    ) -> int:
        """Return the constant 1; neither argument is used."""
        return 1

    def _remove(f: Callable, defaults: "dict[str, Any]") -> Callable:
        """
        Wrap `f` so the keyword arguments in `defaults` are always supplied,
        and hide those parameters from the wrapper's reported signature.
        """
        # pylint: disable=no-self-argument
        def wrapper(self, *args, **kwargs):
            # entries in `defaults` take precedence over caller-supplied kwargs
            return f(self, *args, **{**kwargs, **defaults})

        # Copy name, docstring, annotations, etc. from the wrapped function.
        functools.update_wrapper(wrapper, f)

        # inspect.signature honours __signature__, so publish a signature
        # without the pre-filled parameters (here: system_prompt).
        original = inspect.signature(f)
        kept = [
            param
            for param_name, param in original.parameters.items()
            if param_name not in defaults
        ]
        wrapper.__signature__ = original.replace(parameters=kept)

        # partialmethod lets the wrapper bind `self` when accessed on an instance
        return functools.partialmethod(wrapper)

    predict = _remove(general, {"system_prompt": ""})


def _train(
    prompt: str = Input(description="hi"),
    system_prompt: str = None,
) -> int:
    """Return the constant 1; neither argument is used."""
    return 1


# Training entry point: `_train` with `system_prompt` pre-bound to an empty string.
train = functools.partial(_train, system_prompt="")
45 changes: 45 additions & 0 deletions test-integration/test_integration/test_predict.py
Original file line number Diff line number Diff line change
Expand Up @@ -288,3 +288,48 @@ def test_predict_path_list_input(tmpdir_factory):
)
assert "test1" in result.stdout
assert "test2" in result.stdout


def test_predict_works_with_deferred_annotations():
    """Run `cog predict -i input=world` in the future-annotations fixture; a non-zero exit fails the test."""
    here = Path(__file__).parent

    subprocess.check_call(
        ["cog", "predict", "-i", "input=world"],
        cwd=here / "fixtures/future-annotations-project",
        timeout=DEFAULT_TIMEOUT,
    )


def test_predict_works_with_partial_wrapper():
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Does this need to be an integration test? Could it be a worker test instead?

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I was just following #1772, should test_predict_works_with_deferred_annotations be moved as well?

Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Ah, right. I'd be tempted to move both of those into test_worker if it's possible to reproduce the original failure there. It's not really clear to me what the original failure was, though, so it's possible that doesn't work.

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

https://replicatehq.slack.com/archives/C05600FDYTE/p1723843667345019

    **get_input_create_model_kwargs(signature, input_types),
      ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
  File "/usr/local/lib/python3.11/site-packages/cog/predictor.py", line 306, in get_input_create_model_kwargs
    raise TypeError(f"No input type provided for parameter `{name}`.")
TypeError: No input type provided for parameter `prompt`.

and

  File "/root/.pyenv/versions/3.11.7/lib/python3.11/site-packages/cog/predictor.py", line 374, in get_input_type
    input_types = get_type_hints(predict)
                  ^^^^^^^^^^^^^^^^^^^^^^^
  File "/root/.pyenv/versions/3.11.7/lib/python3.11/typing.py", line 2381, in get_type_hints
    raise TypeError('{!r} is not a module, class, method, '
TypeError: functools.partial(<bound method Predictor.predict of <predict.Predictor object at 0x70abe38f0490>>) is not a module, class, method, or function.

Copy link
Contributor Author

@technillogue technillogue Aug 23, 2024

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I think it needs to at least be a predictor test because I don't think worker calls get_input_type, though a schema test would work too

project_dir = Path(__file__).parent / "fixtures/partial-predict-project"

subprocess.check_call(
["cog", "predict", "-i", "prompt=world"],
cwd=project_dir,
timeout=DEFAULT_TIMEOUT,
)
subprocess.check_call(
["cog", "train", "-i", "prompt=world"],
cwd=project_dir,
timeout=DEFAULT_TIMEOUT,
)


def test_predict_int_none_output():
    """Run `cog predict` with no inputs in the int-none-output fixture; a non-zero exit fails the test."""
    here = Path(__file__).parent

    subprocess.check_call(
        ["cog", "predict"],
        cwd=here / "fixtures/int-none-output-project",
        timeout=DEFAULT_TIMEOUT,
    )


def test_predict_string_none_output():
    """Run `cog predict` with no inputs in the string-none-output fixture; a non-zero exit fails the test."""
    here = Path(__file__).parent

    subprocess.check_call(
        ["cog", "predict"],
        cwd=here / "fixtures/string-none-output-project",
        timeout=DEFAULT_TIMEOUT,
    )
Loading