Skip to content

Commit

Permalink
Revert "Handle predictors with deferred annotations (replicate#1772)" (
Browse files Browse the repository at this point in the history
…replicate#1918)

This reverts commit 05900a7.

This is needed to avoid breaking predictors that rely on __signature__ or partial.
  • Loading branch information
technillogue authored Sep 13, 2024
1 parent dbfa22f commit f76e4d5
Show file tree
Hide file tree
Showing 3 changed files with 13 additions and 65 deletions.
40 changes: 13 additions & 27 deletions python/cog/predictor.py
Original file line number Diff line number Diff line change
Expand Up @@ -18,7 +18,6 @@
Type,
Union,
cast,
get_type_hints,
)

try:
Expand Down Expand Up @@ -283,18 +282,13 @@ def validate_input_type(
)


def get_input_create_model_kwargs(
signature: inspect.Signature, input_types: Dict[str, Any]
) -> Dict[str, Any]:
def get_input_create_model_kwargs(signature: inspect.Signature) -> Dict[str, Any]:
create_model_kwargs = {}

order = 0

for name, parameter in signature.parameters.items():
if name not in input_types:
raise TypeError(f"No input type provided for parameter `{name}`.")

InputType = input_types[name] # pylint: disable=invalid-name
InputType = parameter.annotation

validate_input_type(InputType, name)

Expand Down Expand Up @@ -360,17 +354,13 @@ class Input(BaseModel):
predict = get_predict(predictor)
signature = inspect.signature(predict)

input_types = get_type_hints(predict)
if "return" in input_types:
del input_types["return"]

return create_model(
"Input",
__config__=None,
__base__=BaseInput,
__module__=__name__,
__validators__=None,
**get_input_create_model_kwargs(signature, input_types),
**get_input_create_model_kwargs(signature),
) # type: ignore


Expand All @@ -380,10 +370,9 @@ def get_output_type(predictor: BasePredictor) -> Type[BaseModel]:
"""

predict = get_predict(predictor)

input_types = get_type_hints(predict)

if "return" not in input_types:
signature = inspect.signature(predict)
OutputType: Type[BaseModel]
if signature.return_annotation is inspect.Signature.empty:
raise TypeError(
"""You must set an output type. If your model can return multiple output types, you can explicitly set `Any` as the output type.
Expand All @@ -398,7 +387,8 @@ def predict(
...
"""
)
OutputType = input_types.pop("return") # pylint: disable=invalid-name
else:
OutputType = signature.return_annotation

# The type that goes in the response is a list of the yielded type
if get_origin(OutputType) is Iterator:
Expand Down Expand Up @@ -462,17 +452,13 @@ class TrainingInput(BaseModel):
train = get_train(predictor)
signature = inspect.signature(train)

input_types = get_type_hints(train)
if "return" in input_types:
del input_types["return"]

return create_model(
"TrainingInput",
__config__=None,
__base__=BaseInput,
__module__=__name__,
__validators__=None,
**get_input_create_model_kwargs(signature, input_types),
**get_input_create_model_kwargs(signature),
) # type: ignore


Expand All @@ -482,9 +468,9 @@ def get_training_output_type(predictor: BasePredictor) -> Type[BaseModel]:
"""

train = get_train(predictor)
signature = inspect.signature(train)

input_types = get_type_hints(train)
if "return" not in input_types:
if signature.return_annotation is inspect.Signature.empty:
raise TypeError(
"""You must set an output type. If your model can return multiple output types, you can explicitly set `Any` as the output type.
Expand All @@ -499,8 +485,8 @@ def train(
...
"""
)

TrainingOutputType = input_types.pop("return") # pylint: disable=invalid-name
else:
TrainingOutputType = signature.return_annotation

name = (
TrainingOutputType.__name__ if hasattr(TrainingOutputType, "__name__") else ""
Expand Down

This file was deleted.

30 changes: 0 additions & 30 deletions test-integration/test_integration/test_predict.py
Original file line number Diff line number Diff line change
Expand Up @@ -288,33 +288,3 @@ def test_predict_path_list_input(tmpdir_factory):
)
assert "test1" in result.stdout
assert "test2" in result.stdout


def test_predict_works_with_deferred_annotations():
project_dir = Path(__file__).parent / "fixtures/future-annotations-project"

subprocess.check_call(
["cog", "predict", "-i", "input=world"],
cwd=project_dir,
timeout=DEFAULT_TIMEOUT,
)


def test_predict_int_none_output():
project_dir = Path(__file__).parent / "fixtures/int-none-output-project"

subprocess.check_call(
["cog", "predict"],
cwd=project_dir,
timeout=DEFAULT_TIMEOUT,
)


def test_predict_string_none_output():
project_dir = Path(__file__).parent / "fixtures/string-none-output-project"

subprocess.check_call(
["cog", "predict"],
cwd=project_dir,
timeout=DEFAULT_TIMEOUT,
)

0 comments on commit f76e4d5

Please sign in to comment.