From 09d81a32fce7a4655dddcd2c2d84f10142d8cb1b Mon Sep 17 00:00:00 2001 From: technillogue Date: Fri, 13 Sep 2024 16:54:24 -0400 Subject: [PATCH 1/4] Revert "Revert "Handle predictors with deferred annotations (#1772)" (#1918)" This reverts commit f76e4d570b9f10908cd57255e132cfdd3ebce6ef. --- python/cog/predictor.py | 40 +++++++++++++------ .../future-annotations-project/predict.py | 8 ++++ .../test_integration/test_predict.py | 30 ++++++++++++++ 3 files changed, 65 insertions(+), 13 deletions(-) create mode 100644 test-integration/test_integration/fixtures/future-annotations-project/predict.py diff --git a/python/cog/predictor.py b/python/cog/predictor.py index 50754f2ac7..a2dd4d3ea0 100644 --- a/python/cog/predictor.py +++ b/python/cog/predictor.py @@ -18,6 +18,7 @@ Type, Union, cast, + get_type_hints, ) try: @@ -282,13 +283,18 @@ def validate_input_type( ) -def get_input_create_model_kwargs(signature: inspect.Signature) -> Dict[str, Any]: +def get_input_create_model_kwargs( + signature: inspect.Signature, input_types: Dict[str, Any] +) -> Dict[str, Any]: create_model_kwargs = {} order = 0 for name, parameter in signature.parameters.items(): - InputType = parameter.annotation + if name not in input_types: + raise TypeError(f"No input type provided for parameter `{name}`.") + + InputType = input_types[name] # pylint: disable=invalid-name validate_input_type(InputType, name) @@ -354,13 +360,17 @@ class Input(BaseModel): predict = get_predict(predictor) signature = inspect.signature(predict) + input_types = get_type_hints(predict) + if "return" in input_types: + del input_types["return"] + return create_model( "Input", __config__=None, __base__=BaseInput, __module__=__name__, __validators__=None, - **get_input_create_model_kwargs(signature), + **get_input_create_model_kwargs(signature, input_types), ) # type: ignore @@ -370,9 +380,10 @@ def get_output_type(predictor: BasePredictor) -> Type[BaseModel]: """ predict = get_predict(predictor) - signature = inspect.signature(predict) - OutputType: Type[BaseModel] - if signature.return_annotation is inspect.Signature.empty: + + input_types = get_type_hints(predict) + + if "return" not in input_types: raise TypeError( """You must set an output type. If your model can return multiple output types, you can explicitly set `Any` as the output type. @@ -387,8 +398,7 @@ def predict( ... """ ) - else: - OutputType = signature.return_annotation + OutputType = input_types.pop("return") # pylint: disable=invalid-name # The type that goes in the response is a list of the yielded type if get_origin(OutputType) is Iterator: @@ -452,13 +462,17 @@ class TrainingInput(BaseModel): train = get_train(predictor) signature = inspect.signature(train) + input_types = get_type_hints(train) + if "return" in input_types: + del input_types["return"] + return create_model( "TrainingInput", __config__=None, __base__=BaseInput, __module__=__name__, __validators__=None, - **get_input_create_model_kwargs(signature), + **get_input_create_model_kwargs(signature, input_types), ) # type: ignore @@ -468,9 +482,9 @@ def get_training_output_type(predictor: BasePredictor) -> Type[BaseModel]: """ train = get_train(predictor) - signature = inspect.signature(train) - if signature.return_annotation is inspect.Signature.empty: + input_types = get_type_hints(train) + if "return" not in input_types: raise TypeError( """You must set an output type. If your model can return multiple output types, you can explicitly set `Any` as the output type. @@ -485,8 +499,8 @@ def train( ... """ ) - else: - TrainingOutputType = signature.return_annotation + + TrainingOutputType = input_types.pop("return") # pylint: disable=invalid-name name = ( TrainingOutputType.__name__ if hasattr(TrainingOutputType, "__name__") else "" diff --git a/test-integration/test_integration/fixtures/future-annotations-project/predict.py b/test-integration/test_integration/fixtures/future-annotations-project/predict.py new file mode 100644 index 0000000000..791d2218fd --- /dev/null +++ b/test-integration/test_integration/fixtures/future-annotations-project/predict.py @@ -0,0 +1,8 @@ +from __future__ import annotations + +from cog import BasePredictor, Input + + +class Predictor(BasePredictor): + def predict(self, input: str = Input(description="Who to greet")) -> str: + return "hello " + input diff --git a/test-integration/test_integration/test_predict.py b/test-integration/test_integration/test_predict.py index d6da057a96..34b81d3891 100644 --- a/test-integration/test_integration/test_predict.py +++ b/test-integration/test_integration/test_predict.py @@ -288,3 +288,33 @@ def test_predict_path_list_input(tmpdir_factory): ) assert "test1" in result.stdout assert "test2" in result.stdout + + +def test_predict_works_with_deferred_annotations(): + project_dir = Path(__file__).parent / "fixtures/future-annotations-project" + + subprocess.check_call( + ["cog", "predict", "-i", "input=world"], + cwd=project_dir, + timeout=DEFAULT_TIMEOUT, + ) + + +def test_predict_int_none_output(): + project_dir = Path(__file__).parent / "fixtures/int-none-output-project" + + subprocess.check_call( + ["cog", "predict"], + cwd=project_dir, + timeout=DEFAULT_TIMEOUT, + ) + + +def test_predict_string_none_output(): + project_dir = Path(__file__).parent / "fixtures/string-none-output-project" + + subprocess.check_call( + ["cog", "predict"], + cwd=project_dir, + timeout=DEFAULT_TIMEOUT, + ) From 1d8dea9a2b5c9840b630d6bda3a3c36f2fd2a011 Mon Sep 17 00:00:00 2001 From: technillogue Date: Wed, 21 Aug 2024 19:15:07 -0400 Subject: [PATCH 2/4] add test for partial --- .../fixtures/partial-predict-project/cog.yaml | 4 ++ .../partial-predict-project/predict.py | 38 +++++++++++++++++++ .../test_integration/test_predict.py | 14 +++++++ 3 files changed, 56 insertions(+) create mode 100644 test-integration/test_integration/fixtures/partial-predict-project/cog.yaml create mode 100644 test-integration/test_integration/fixtures/partial-predict-project/predict.py diff --git a/test-integration/test_integration/fixtures/partial-predict-project/cog.yaml b/test-integration/test_integration/fixtures/partial-predict-project/cog.yaml new file mode 100644 index 0000000000..90d3b36b63 --- /dev/null +++ b/test-integration/test_integration/fixtures/partial-predict-project/cog.yaml @@ -0,0 +1,4 @@ +build: + python_version: "3.8" +predict: "predict.py:Predictor" +train: "predict.py:train" diff --git a/test-integration/test_integration/fixtures/partial-predict-project/predict.py b/test-integration/test_integration/fixtures/partial-predict-project/predict.py new file mode 100644 index 0000000000..0e24def50c --- /dev/null +++ b/test-integration/test_integration/fixtures/partial-predict-project/predict.py @@ -0,0 +1,38 @@ +from typing import Callable +import functools + +from cog import BasePredictor, Input + + +class Predictor(BasePredictor): + def general( + self, prompt: str = Input(description="hi"), system_prompt: str = None + ) -> int: + return 1 + + def _remove(f: Callable, defaults: dict[str, Any]) -> Callable: + # pylint: disable=no-self-argument + def wrapper(self, *args, **kwargs): + kwargs.update(defaults) + return f(self, *args, **kwargs) + + # Update wrapper attributes for documentation, etc. + functools.update_wrapper(wrapper, f) + + # for the purposes of inspect.signature as used by predictor.get_input_type, + # remove the argument (system_prompt) + sig = inspect.signature(f) + params = [p for name, p in sig.parameters.items() if name not in defaults] + wrapper.__signature__ = sig.replace(parameters=params) + + # Return partialmethod, wrapper behaves correctly when part of a class + return functools.partialmethod(wrapper) + + predict = _remove(general, {"system_prompt": ""}) + + +def _train(self, prompt: str = Input(description="hi"), system_prompt: str = None): + return 1 + + +train = functools.partial(_train, system_prompt="") diff --git a/test-integration/test_integration/test_predict.py b/test-integration/test_integration/test_predict.py index 34b81d3891..31e1a1085b 100644 --- a/test-integration/test_integration/test_predict.py +++ b/test-integration/test_integration/test_predict.py @@ -299,6 +299,20 @@ def test_predict_works_with_deferred_annotations(): timeout=DEFAULT_TIMEOUT, ) +def test_predict_works_with_partial_wrapper(): + project_dir = Path(__file__).parent / "fixtures/partial-predict-project" + + subprocess.check_call( + ["cog", "predict", "-i", "prompt=world"], + cwd=project_dir, + timeout=DEFAULT_TIMEOUT, + ) + subprocess.check_call( + ["cog", "train", "-i", "prompt=world"], + cwd=project_dir, + timeout=DEFAULT_TIMEOUT, + ) + def test_predict_int_none_output(): project_dir = Path(__file__).parent / "fixtures/int-none-output-project" From 772052399f7fb0791570605f67e6f287bb304140 Mon Sep 17 00:00:00 2001 From: technillogue Date: Wed, 21 Aug 2024 19:04:20 -0400 Subject: [PATCH 3/4] fix defered annotations --- python/cog/predictor.py | 40 ++++++++++++------- .../partial-predict-project/predict.py | 7 ++-- 2 files changed, 29 insertions(+), 18 deletions(-) diff --git a/python/cog/predictor.py b/python/cog/predictor.py index a2dd4d3ea0..45056e8be8 100644 --- a/python/cog/predictor.py +++ b/python/cog/predictor.py @@ -291,10 +291,7 @@ def get_input_create_model_kwargs( order = 0 for name, parameter in signature.parameters.items(): - if name not in input_types: - raise TypeError(f"No input type provided for parameter `{name}`.") - - InputType = input_types[name] # pylint: disable=invalid-name + InputType = input_types.get(name, parameter.annotation) validate_input_type(InputType, name) @@ -360,7 +357,10 @@ class Input(BaseModel): predict = get_predict(predictor) signature = inspect.signature(predict) - input_types = get_type_hints(predict) + try: + input_types = get_type_hints(predict) + except TypeError: + input_types = {} if "return" in input_types: del input_types["return"] @@ -374,16 +374,25 @@ class Input(BaseModel): ) # type: ignore +def get_return_annotation(fn: Callable[..., Any]) -> Optional[Type[Any]]: + try: + return get_type_hints(fn).get("return", None) + except TypeError: + return_annotation = inspect.signature(fn).return_annotation + if return_annotation is inspect.Signature.empty: + return None + return return_annotation + + def get_output_type(predictor: BasePredictor) -> Type[BaseModel]: """ Creates a Pydantic Output model from the return type annotation of a Predictor's predict() method. """ predict = get_predict(predictor) + maybe_output_type = get_return_annotation(predict) - input_types = get_type_hints(predict) - - if "return" not in input_types: + if maybe_output_type is None: raise TypeError( """You must set an output type. If your model can return multiple output types, you can explicitly set `Any` as the output type. @@ -398,8 +407,8 @@ def predict( ... """ ) - OutputType = input_types.pop("return") # pylint: disable=invalid-name - + # we need the indirection to narrow the type to the one that's declared later + OutputType = maybe_output_type # pylint: disable=invalid-name # The type that goes in the response is a list of the yielded type if get_origin(OutputType) is Iterator: # Annotated allows us to attach Field annotations to the list, which we use to mark that this is an iterator @@ -462,7 +471,10 @@ class TrainingInput(BaseModel): train = get_train(predictor) signature = inspect.signature(train) - input_types = get_type_hints(train) + try: + input_types = get_type_hints(train) + except TypeError: + input_types = {} if "return" in input_types: del input_types["return"] @@ -483,8 +495,8 @@ def get_training_output_type(predictor: BasePredictor) -> Type[BaseModel]: train = get_train(predictor) - input_types = get_type_hints(train) - if "return" not in input_types: + TrainingOutputType = get_return_annotation(train) # pylint: disable=invalid-name + if not TrainingOutputType: raise TypeError( """You must set an output type. If your model can return multiple output types, you can explicitly set `Any` as the output type. @@ -500,8 +512,6 @@ def train( """ ) - TrainingOutputType = input_types.pop("return") # pylint: disable=invalid-name - name = ( TrainingOutputType.__name__ if hasattr(TrainingOutputType, "__name__") else "" ) diff --git a/test-integration/test_integration/fixtures/partial-predict-project/predict.py b/test-integration/test_integration/fixtures/partial-predict-project/predict.py index 0e24def50c..8665351cf7 100644 --- a/test-integration/test_integration/fixtures/partial-predict-project/predict.py +++ b/test-integration/test_integration/fixtures/partial-predict-project/predict.py @@ -1,5 +1,6 @@ -from typing import Callable import functools +import inspect +from typing import Any, Callable from cog import BasePredictor, Input @@ -10,7 +11,7 @@ def general( ) -> int: return 1 - def _remove(f: Callable, defaults: dict[str, Any]) -> Callable: + def _remove(f: Callable, defaults: "dict[str, Any]") -> Callable: # pylint: disable=no-self-argument def wrapper(self, *args, **kwargs): kwargs.update(defaults) @@ -31,7 +32,7 @@ def wrapper(self, *args, **kwargs): predict = _remove(general, {"system_prompt": ""}) -def _train(self, prompt: str = Input(description="hi"), system_prompt: str = None): +def _train(prompt: str = Input(description="hi"), system_prompt: str = None) -> int: return 1 From 77893858792655ef2b9b08aba3bd04fd3a6d12bb Mon Sep 17 00:00:00 2001 From: technillogue Date: Fri, 30 Aug 2024 14:19:18 -0400 Subject: [PATCH 4/4] tweaks to separate -> None from -> Signature.empty --- python/cog/predictor.py | 20 +++++++++---------- .../test_integration/test_predict.py | 1 + 2 files changed, 10 insertions(+), 11 deletions(-) diff --git a/python/cog/predictor.py b/python/cog/predictor.py index 45056e8be8..9f0b2ad778 100644 --- a/python/cog/predictor.py +++ b/python/cog/predictor.py @@ -8,6 +8,7 @@ import uuid from abc import ABC, abstractmethod from collections.abc import Iterator +from inspect import Signature from pathlib import Path from typing import ( Any, @@ -267,7 +268,7 @@ def validate_input_type( type: Type[Any], # pylint: disable=redefined-builtin name: str, ) -> None: - if type is inspect.Signature.empty: + if type is Signature.empty: raise TypeError( f"No input type provided for parameter `{name}`. Supported input types are: {readable_types_list(ALLOWED_INPUT_TYPES)}, or a Union or List of those types." ) @@ -284,7 +285,7 @@ def validate_input_type( def get_input_create_model_kwargs( - signature: inspect.Signature, input_types: Dict[str, Any] + signature: Signature, input_types: Dict[str, Any] ) -> Dict[str, Any]: create_model_kwargs = {} @@ -296,7 +297,7 @@ def get_input_create_model_kwargs( validate_input_type(InputType, name) # if no default is specified, create an empty, required input - if parameter.default is inspect.Signature.empty: + if parameter.default is Signature.empty: default = Input() else: default = parameter.default @@ -374,14 +375,11 @@ class Input(BaseModel): ) # type: ignore -def get_return_annotation(fn: Callable[..., Any]) -> Optional[Type[Any]]: +def get_return_annotation(fn: Callable[..., Any]) -> "Type[Any]|Signature.empty": try: - return get_type_hints(fn).get("return", None) + return get_type_hints(fn).get("return", Signature.empty) except TypeError: - return_annotation = inspect.signature(fn).return_annotation - if return_annotation is inspect.Signature.empty: - return None - return return_annotation + return inspect.signature(fn).return_annotation def get_output_type(predictor: BasePredictor) -> Type[BaseModel]: @@ -392,7 +390,7 @@ def get_output_type(predictor: BasePredictor) -> Type[BaseModel]: predict = get_predict(predictor) maybe_output_type = get_return_annotation(predict) - if maybe_output_type is None: + if maybe_output_type is Signature.empty: raise TypeError( """You must set an output type. If your model can return multiple output types, you can explicitly set `Any` as the output type. @@ -496,7 +494,7 @@ def get_training_output_type(predictor: BasePredictor) -> Type[BaseModel]: train = get_train(predictor) TrainingOutputType = get_return_annotation(train) # pylint: disable=invalid-name - if not TrainingOutputType: + if TrainingOutputType is Signature.empty: raise TypeError( """You must set an output type. If your model can return multiple output types, you can explicitly set `Any` as the output type. diff --git a/test-integration/test_integration/test_predict.py b/test-integration/test_integration/test_predict.py index 31e1a1085b..2bb9295661 100644 --- a/test-integration/test_integration/test_predict.py +++ b/test-integration/test_integration/test_predict.py @@ -299,6 +299,7 @@ def test_predict_works_with_deferred_annotations(): timeout=DEFAULT_TIMEOUT, ) + def test_predict_works_with_partial_wrapper(): project_dir = Path(__file__).parent / "fixtures/partial-predict-project"