From 55b0f1b27c4cc28ecbfd17461189ac16ade2bd06 Mon Sep 17 00:00:00 2001 From: Dan Buch Date: Tue, 4 Feb 2025 23:33:03 -0500 Subject: [PATCH] Adjust `Trainings#async_create` signature to better align with `Trainings.create` and the way that the arguments are being used. Closes #408 --- replicate/training.py | 16 ++++++++++++---- 1 file changed, 12 insertions(+), 4 deletions(-) diff --git a/replicate/training.py b/replicate/training.py index 28e28b4a..26981abb 100644 --- a/replicate/training.py +++ b/replicate/training.py @@ -307,9 +307,9 @@ def create( # type: ignore async def async_create( self, - model: Union[str, Tuple[str, str], "Model"], - version: Union[str, Version], - input: Dict[str, Any], + model: Optional[Union[str, Tuple[str, str], "Model"]] = None, + version: Optional[Union[str, Version]] = None, + input: Optional[Dict[str, Any]] = None, **params: Unpack["Trainings.CreateTrainingParams"], ) -> Training: """ @@ -326,7 +326,15 @@ async def async_create( The training object. """ - url = _create_training_url_from_model_and_version(model, version) + url = None + + if model and version: + url = _create_training_url_from_model_and_version(model, version) + elif model is None and isinstance(version, str): + url = _create_training_url_from_shorthand(version) + + if not url: + raise ValueError("model and version or shorthand version must be specified") file_encoding_strategy = params.pop("file_encoding_strategy", None) if input is not None: