diff --git a/replicate/training.py b/replicate/training.py index 28e28b4..26981ab 100644 --- a/replicate/training.py +++ b/replicate/training.py @@ -307,9 +307,9 @@ def create( # type: ignore async def async_create( self, - model: Union[str, Tuple[str, str], "Model"], - version: Union[str, Version], - input: Dict[str, Any], + model: Optional[Union[str, Tuple[str, str], "Model"]] = None, + version: Optional[Union[str, Version]] = None, + input: Optional[Dict[str, Any]] = None, **params: Unpack["Trainings.CreateTrainingParams"], ) -> Training: """ @@ -326,7 +326,15 @@ async def async_create( The training object. """ - url = _create_training_url_from_model_and_version(model, version) + url = None + + if model and version: + url = _create_training_url_from_model_and_version(model, version) + elif model is None and isinstance(version, str): + url = _create_training_url_from_shorthand(version) + + if not url: + raise ValueError("model and version or shorthand version must be specified") file_encoding_strategy = params.pop("file_encoding_strategy", None) if input is not None: