From 004343441ce77602e5d0be536b50ea5eac330f8e Mon Sep 17 00:00:00 2001 From: Jasha10 <8935917+Jasha10@users.noreply.github.com> Date: Tue, 5 Apr 2022 12:08:51 -0500 Subject: [PATCH] Cherry-picked instantiate improvements (#2120) --- hydra/_internal/instantiate/_instantiate2.py | 139 +++++++++---- hydra/_internal/utils.py | 81 ++++---- hydra/utils.py | 17 +- news/1950.feature | 1 + news/2099.feature | 1 + tests/instantiate/__init__.py | 15 +- tests/instantiate/import_error.py | 2 + .../module_shadowed_by_function.py | 3 + tests/instantiate/test_helpers.py | 195 +++++++++++++++++- tests/instantiate/test_instantiate.py | 184 +++++++++++++++-- tests/instantiate/test_positional.py | 41 +++- 11 files changed, 571 insertions(+), 108 deletions(-) create mode 100644 news/1950.feature create mode 100644 news/2099.feature create mode 100644 tests/instantiate/import_error.py create mode 100644 tests/instantiate/module_shadowed_by_function.py diff --git a/hydra/_internal/instantiate/_instantiate2.py b/hydra/_internal/instantiate/_instantiate2.py index 43dc0857fdd..49a90e6c237 100644 --- a/hydra/_internal/instantiate/_instantiate2.py +++ b/hydra/_internal/instantiate/_instantiate2.py @@ -2,9 +2,9 @@ import copy import functools -import sys from enum import Enum -from typing import Any, Callable, Sequence, Tuple, Union +from textwrap import dedent +from typing import Any, Callable, Dict, List, Sequence, Tuple, Union from omegaconf import OmegaConf, SCMode from omegaconf._utils import is_structured_config @@ -32,7 +32,7 @@ def _is_target(x: Any) -> bool: return False -def _extract_pos_args(*input_args: Any, **kwargs: Any) -> Tuple[Any, Any]: +def _extract_pos_args(input_args: Any, kwargs: Any) -> Tuple[Any, Any]: config_args = kwargs.pop(_Keys.ARGS, ()) output_args = config_args @@ -41,16 +41,16 @@ def _extract_pos_args(*input_args: Any, **kwargs: Any) -> Tuple[Any, Any]: output_args = input_args else: raise InstantiationException( - f"Unsupported _args_ type: {type(config_args).__name__}. value: {config_args}" + f"Unsupported _args_ type: '{type(config_args).__name__}'. value: '{config_args}'" ) return output_args, kwargs -def _call_target(_target_: Callable, _partial_: bool, *args, **kwargs) -> Any: # type: ignore +def _call_target(_target_: Callable, _partial_: bool, args, kwargs, full_key) -> Any: # type: ignore """Call target (type) with args and kwargs.""" try: - args, kwargs = _extract_pos_args(*args, **kwargs) + args, kwargs = _extract_pos_args(args, kwargs) # detaching configs from parent. # At this time, everything is resolved and the parent link can cause # issues when serializing objects in some scenarios. @@ -60,13 +60,35 @@ def _call_target(_target_: Callable, _partial_: bool, *args, **kwargs) -> Any: for v in kwargs.values(): if OmegaConf.is_config(v): v._set_parent(None) - if _partial_: - return functools.partial(_target_, *args, **kwargs) - return _target_(*args, **kwargs) except Exception as e: - raise type(e)( - f"Error instantiating '{_convert_target_to_string(_target_)}' : {e}" - ).with_traceback(sys.exc_info()[2]) + msg = ( + f"Error in collecting args and kwargs for '{_convert_target_to_string(_target_)}':" + + f"\n{repr(e)}" + ) + if full_key: + msg += f"\nfull_key: {full_key}" + + raise InstantiationException(msg) from e + + if _partial_: + try: + return functools.partial(_target_, *args, **kwargs) + except Exception as e: + msg = ( + f"Error in creating partial({_convert_target_to_string(_target_)}, ...) object:" + + f"\n{repr(e)}" + ) + if full_key: + msg += f"\nfull_key: {full_key}" + raise InstantiationException(msg) from e + else: + try: + return _target_(*args, **kwargs) + except Exception as e: + msg = f"Error in call to target '{_convert_target_to_string(_target_)}':\n{repr(e)}" + if full_key: + msg += f"\nfull_key: {full_key}" + raise InstantiationException(msg) from e def _convert_target_to_string(t: Any) -> Any: @@ -78,7 +100,7 @@ def _convert_target_to_string(t: Any) -> Any: return t -def _prepare_input_dict(d: Any) -> Any: +def _prepare_input_dict_or_list(d: Union[Dict[Any, Any], List[Any]]) -> Any: res: Any if isinstance(d, dict): res = {} @@ -86,13 +108,13 @@ def _prepare_input_dict(d: Any) -> Any: if k == "_target_": v = _convert_target_to_string(d["_target_"]) elif isinstance(v, (dict, list)): - v = _prepare_input_dict(v) + v = _prepare_input_dict_or_list(v) res[k] = v elif isinstance(d, list): res = [] for v in d: if isinstance(v, (list, dict)): - v = _prepare_input_dict(v) + v = _prepare_input_dict_or_list(v) res.append(v) else: assert False @@ -100,18 +122,23 @@ def _prepare_input_dict(d: Any) -> Any: def _resolve_target( - target: Union[str, type, Callable[..., Any]] + target: Union[str, type, Callable[..., Any]], full_key: str ) -> Union[type, Callable[..., Any]]: """Resolve target string, type or callable into type or callable.""" if isinstance(target, str): - return _locate(target) - if isinstance(target, type): - return target - if callable(target): - return target - raise InstantiationException( - f"Unsupported target type: {type(target).__name__}. value: {target}" - ) + try: + target = _locate(target) + except Exception as e: + msg = f"Error locating target '{target}', see chained exception above." + if full_key: + msg += f"\nfull_key: {full_key}" + raise InstantiationException(msg) from e + if not callable(target): + msg = f"Expected a callable target, got '{target}' of type '{type(target).__name__}'" + if full_key: + msg += f"\nfull_key: {full_key}" + raise InstantiationException(msg) + return target def instantiate(config: Any, *args: Any, **kwargs: Any) -> Any: @@ -151,17 +178,23 @@ def instantiate(config: Any, *args: Any, **kwargs: Any) -> Any: if isinstance(config, TargetConf) and config._target_ == "???": # Specific check to give a good warning about failure to annotate _target_ as a string. raise InstantiationException( - f"Missing value for {type(config).__name__}._target_. Check that it's properly annotated and overridden." - f"\nA common problem is forgetting to annotate _target_ as a string : '_target_: str = ...'" + dedent( + f"""\ + Config has missing value for key `_target_`, cannot instantiate. + Config type: {type(config).__name__} + Check that the `_target_` key in your dataclass is properly annotated and overridden. + A common problem is forgetting to annotate _target_ as a string : '_target_: str = ...'""" + ) ) + # TODO: print full key - if isinstance(config, dict): - config = _prepare_input_dict(config) + if isinstance(config, (dict, list)): + config = _prepare_input_dict_or_list(config) - kwargs = _prepare_input_dict(kwargs) + kwargs = _prepare_input_dict_or_list(kwargs) # Structured Config always converted first to OmegaConf - if is_structured_config(config) or isinstance(config, dict): + if is_structured_config(config) or isinstance(config, (dict, list)): config = OmegaConf.structured(config, flags={"allow_objects": True}) if OmegaConf.is_dict(config): @@ -182,12 +215,40 @@ def instantiate(config: Any, *args: Any, **kwargs: Any) -> Any: _convert_ = config.pop(_Keys.CONVERT, ConvertMode.NONE) _partial_ = config.pop(_Keys.PARTIAL, False) + return instantiate_node( + config, *args, recursive=_recursive_, convert=_convert_, partial=_partial_ + ) + elif OmegaConf.is_list(config): + # Finalize config (convert targets to strings, merge with kwargs) + config_copy = copy.deepcopy(config) + config_copy._set_flag( + flags=["allow_objects", "struct", "readonly"], values=[True, False, False] + ) + config_copy._set_parent(config._get_parent()) + config = config_copy + + OmegaConf.resolve(config) + + _recursive_ = kwargs.pop(_Keys.RECURSIVE, True) + _convert_ = kwargs.pop(_Keys.CONVERT, ConvertMode.NONE) + _partial_ = kwargs.pop(_Keys.PARTIAL, False) + + if _partial_: + raise InstantiationException( + "The _partial_ keyword is not compatible with top-level list instantiation" + ) + return instantiate_node( config, *args, recursive=_recursive_, convert=_convert_, partial=_partial_ ) else: raise InstantiationException( - "Top level config has to be OmegaConf DictConfig, plain dict, or a Structured Config class or instance" + dedent( + f"""\ + Cannot instantiate config of type {type(config).__name__}. + Top level config must be an OmegaConf DictConfig/ListConfig object, + a plain dict/list, or a Structured Config class or instance.""" + ) ) @@ -224,11 +285,19 @@ def instantiate_node( recursive = node[_Keys.RECURSIVE] if _Keys.RECURSIVE in node else recursive partial = node[_Keys.PARTIAL] if _Keys.PARTIAL in node else partial + full_key = node._get_full_key(None) + if not isinstance(recursive, bool): - raise TypeError(f"_recursive_ flag must be a bool, got {type(recursive)}") + msg = f"Instantiation: _recursive_ flag must be a bool, got {type(recursive)}" + if full_key: + msg += f"\nfull_key: {full_key}" + raise TypeError(msg) if not isinstance(partial, bool): - raise TypeError(f"_partial_ flag must be a bool, got {type( partial )}") + msg = f"Instantiation: _partial_ flag must be a bool, got {type( partial )}" + if node and full_key: + msg += f"\nfull_key: {full_key}" + raise TypeError(msg) # If OmegaConf list, create new list of instances if recursive if OmegaConf.is_list(node): @@ -249,7 +318,7 @@ def instantiate_node( elif OmegaConf.is_dict(node): exclude_keys = set({"_target_", "_convert_", "_recursive_", "_partial_"}) if _is_target(node): - _target_ = _resolve_target(node.get(_Keys.TARGET)) + _target_ = _resolve_target(node.get(_Keys.TARGET), full_key) kwargs = {} for key, value in node.items(): if key not in exclude_keys: @@ -259,7 +328,7 @@ def instantiate_node( ) kwargs[key] = _convert_node(value, convert) - return _call_target(_target_, partial, *args, **kwargs) + return _call_target(_target_, partial, args, kwargs, full_key) else: # If ALL or PARTIAL non structured, instantiate in dict and resolve interpolations eagerly. if convert == ConvertMode.ALL or ( diff --git a/hydra/_internal/utils.py b/hydra/_internal/utils.py index 18f556bacdb..c03a14295a4 100644 --- a/hydra/_internal/utils.py +++ b/hydra/_internal/utils.py @@ -7,8 +7,8 @@ from dataclasses import dataclass from os.path import dirname, join, normpath, realpath from traceback import print_exc, print_exception -from types import FrameType, TracebackType -from typing import Any, Callable, List, Optional, Sequence, Tuple, Union +from types import FrameType, ModuleType, TracebackType +from typing import Any, List, Optional, Sequence, Tuple from omegaconf.errors import OmegaConfBaseException @@ -551,7 +551,7 @@ def get_column_widths(matrix: List[List[str]]) -> List[int]: return widths -def _locate(path: str) -> Union[type, Callable[..., Any]]: +def _locate(path: str) -> Any: """ Locate an object by name or dotted path, importing as necessary. This is similar to the pydoc function `locate`, except that it checks for @@ -559,44 +559,49 @@ def _locate(path: str) -> Union[type, Callable[..., Any]]: """ if path == "": raise ImportError("Empty path") - import builtins from importlib import import_module - parts = [part for part in path.split(".") if part] - module = None - for n in reversed(range(len(parts))): + parts = [part for part in path.split(".")] + for part in parts: + if not len(part): + raise ValueError( + f"Error loading '{path}': invalid dotstring." + + "\nRelative imports are not supported." + ) + assert len(parts) > 0 + part0 = parts[0] + try: + obj = import_module(part0) + except Exception as exc_import: + raise ImportError( + f"Error loading '{path}':\n{repr(exc_import)}" + + f"\nAre you sure that module '{part0}' is installed?" + ) from exc_import + for m in range(1, len(parts)): + part = parts[m] try: - mod = ".".join(parts[:n]) - module = import_module(mod) - except Exception as e: - if n == 0: - raise ImportError(f"Error loading module '{path}'") from e - continue - if module: - break - if module: - obj = module - else: - obj = builtins - for part in parts[n:]: - mod = mod + "." + part - if not hasattr(obj, part): - try: - import_module(mod) - except Exception as e: - raise ImportError( - f"Encountered error: `{e}` when loading module '{path}'" - ) from e - obj = getattr(obj, part) - if isinstance(obj, type): - obj_type: type = obj - return obj_type - elif callable(obj): - obj_callable: Callable[..., Any] = obj - return obj_callable - else: - # dummy case - raise ValueError(f"Invalid type ({type(obj)}) found for {path}") + obj = getattr(obj, part) + except AttributeError as exc_attr: + parent_dotpath = ".".join(parts[:m]) + if isinstance(obj, ModuleType): + mod = ".".join(parts[: m + 1]) + try: + obj = import_module(mod) + continue + except ModuleNotFoundError as exc_import: + raise ImportError( + f"Error loading '{path}':\n{repr(exc_import)}" + + f"\nAre you sure that '{part}' is importable from module '{parent_dotpath}'?" + ) from exc_import + except Exception as exc_import: + raise ImportError( + f"Error loading '{path}':\n{repr(exc_import)}" + ) from exc_import + raise ImportError( + f"Error loading '{path}':\n{repr(exc_attr)}" + + f"\nAre you sure that '{part}' is an attribute of '{parent_dotpath}'?" + ) from exc_attr + return obj def _get_cls_name(config: Any, pop: bool = True) -> str: diff --git a/hydra/utils.py b/hydra/utils.py index 6583722f144..33c1ac1dc8d 100644 --- a/hydra/utils.py +++ b/hydra/utils.py @@ -22,18 +22,25 @@ def get_class(path: str) -> type: try: cls = _locate(path) if not isinstance(cls, type): - raise ValueError(f"Located non-class in {path} : {type(cls).__name__}") + raise ValueError( + f"Located non-class of type '{type(cls).__name__}'" + + f" while loading '{path}'" + ) return cls except Exception as e: - log.error(f"Error initializing class at {path} : {e}") + log.error(f"Error initializing class at {path}: {e}") raise e def get_method(path: str) -> Callable[..., Any]: try: - cl = _locate(path) - if not callable(cl): - raise ValueError(f"Non callable object located : {type(cl).__name__}") + obj = _locate(path) + if not callable(obj): + raise ValueError( + f"Located non-callable of type '{type(obj).__name__}'" + + f" while loading '{path}'" + ) + cl: Callable[..., Any] = obj return cl except Exception as e: log.error(f"Error getting callable at {path} : {e}") diff --git a/news/1950.feature b/news/1950.feature new file mode 100644 index 00000000000..b57536960e9 --- /dev/null +++ b/news/1950.feature @@ -0,0 +1 @@ +The `instantiate` API now accepts `ListConfig`/`list`-type config as top-level input. diff --git a/news/2099.feature b/news/2099.feature new file mode 100644 index 00000000000..293e325303e --- /dev/null +++ b/news/2099.feature @@ -0,0 +1 @@ +Improve error messages raised in case of instantiation failure. diff --git a/tests/instantiate/__init__.py b/tests/instantiate/__init__.py index 7517fe69e28..4d6653c4990 100644 --- a/tests/instantiate/__init__.py +++ b/tests/instantiate/__init__.py @@ -3,11 +3,14 @@ import collections.abc from dataclasses import dataclass from functools import partial -from typing import Any, Dict, List, Optional, Tuple +from typing import Any, Dict, List, NoReturn, Optional, Tuple from omegaconf import MISSING, DictConfig, ListConfig from hydra.types import TargetConf +from tests.instantiate.module_shadowed_by_function import a_function + +module_shadowed_by_function = a_function def _convert_type(obj: Any) -> Any: @@ -72,6 +75,16 @@ def module_function(x: int) -> int: return x +class ExceptionTakingNoArgument(Exception): + def __init__(self) -> None: + """Init method taking only one argument (self)""" + super().__init__("Err message") + + +def raise_exception_taking_no_argument() -> NoReturn: + raise ExceptionTakingNoArgument() + + @dataclass class AClass: a: Any diff --git a/tests/instantiate/import_error.py b/tests/instantiate/import_error.py new file mode 100644 index 00000000000..816f9acb238 --- /dev/null +++ b/tests/instantiate/import_error.py @@ -0,0 +1,2 @@ +# Copyright (c) Facebook, Inc. and its affiliates. All Rights Reserved +assert False diff --git a/tests/instantiate/module_shadowed_by_function.py b/tests/instantiate/module_shadowed_by_function.py new file mode 100644 index 00000000000..2930b2e0d0e --- /dev/null +++ b/tests/instantiate/module_shadowed_by_function.py @@ -0,0 +1,3 @@ +# Copyright (c) Facebook, Inc. and its affiliates. All Rights Reserved +def a_function() -> None: + pass diff --git a/tests/instantiate/test_helpers.py b/tests/instantiate/test_helpers.py index e55e0066b52..37d40784369 100644 --- a/tests/instantiate/test_helpers.py +++ b/tests/instantiate/test_helpers.py @@ -1,12 +1,14 @@ # Copyright (c) Facebook, Inc. and its affiliates. All Rights Reserved +import datetime import re +from textwrap import dedent from typing import Any from _pytest.python_api import RaisesContext, raises -from pytest import mark +from pytest import mark, param from hydra._internal.utils import _locate -from hydra.utils import get_class +from hydra.utils import get_class, get_method from tests.instantiate import ( AClass, Adam, @@ -16,6 +18,8 @@ Parameters, ) +from .module_shadowed_by_function import a_function + @mark.parametrize( "name,expected", @@ -27,16 +31,122 @@ ("tests.instantiate.NestingClass", NestingClass), ("tests.instantiate.AnotherClass", AnotherClass), ("", raises(ImportError, match=re.escape("Empty path"))), - [ + ( "not_found", - raises(ImportError, match=re.escape("Error loading module 'not_found'")), - ], + raises(ImportError, match=re.escape("Error loading 'not_found'")), + ), ( "tests.instantiate.b.c.Door", raises( ImportError, match=re.escape("No module named 'tests.instantiate.b'") ), ), + param( + "int", + raises( + ImportError, + match=dedent( + r""" + Error loading 'int': + ModuleNotFoundError\("No module named 'int'",?\) + Are you sure that module 'int' is installed\? + """ + ).strip(), + ), + id="int", + ), + param("builtins.int", int, id="builtins_explicit"), + param("builtins.int.from_bytes", int.from_bytes, id="method_of_builtin"), + param( + "builtins.int.not_found", + raises( + ImportError, + match=dedent( + r""" + Error loading 'builtins\.int\.not_found': + AttributeError\("type object 'int' has no attribute 'not_found'",?\) + Are you sure that 'not_found' is an attribute of 'builtins\.int'\? + """ + ).strip(), + ), + id="builtin_attribute_error", + ), + param( + "datetime", + datetime, + id="top_level_module", + ), + ("tests.instantiate.Adam", Adam), + ("tests.instantiate.Parameters", Parameters), + ("tests.instantiate.AClass", AClass), + param( + "tests.instantiate.AClass.static_method", + AClass.static_method, + id="staticmethod", + ), + param( + "tests.instantiate.AClass.not_found", + raises( + ImportError, + match=dedent( + r""" + Error loading 'tests\.instantiate\.AClass\.not_found': + AttributeError\("type object 'AClass' has no attribute 'not_found'",?\) + Are you sure that 'not_found' is an attribute of 'tests\.instantiate\.AClass'\? + """ + ).strip(), + ), + id="class_attribute_error", + ), + ("tests.instantiate.ASubclass", ASubclass), + ("tests.instantiate.NestingClass", NestingClass), + ("tests.instantiate.AnotherClass", AnotherClass), + ("tests.instantiate.module_shadowed_by_function", a_function), + param( + "", + raises(ImportError, match=("Empty path")), + id="invalid-path-empty", + ), + param( + "toplevel_not_found", + raises( + ImportError, + match=dedent( + r""" + Error loading 'toplevel_not_found': + ModuleNotFoundError\("No module named 'toplevel_not_found'",?\) + Are you sure that module 'toplevel_not_found' is installed\? + """ + ).strip(), + ), + id="toplevel_not_found", + ), + param( + "tests.instantiate.b.c.Door", + raises( + ImportError, + match=dedent( + r""" + Error loading 'tests\.instantiate\.b\.c\.Door': + ModuleNotFoundError\("No module named 'tests\.instantiate\.b'",?\) + Are you sure that 'b' is importable from module 'tests\.instantiate'\?""" + ).strip(), + ), + id="nested_not_found", + ), + param( + "tests.instantiate.import_error", + raises( + ImportError, + match=re.escape( + dedent( + """\ + Error loading 'tests.instantiate.import_error': + AssertionError()""" + ) + ), + ), + ), ], ) def test_locate(name: str, expected: Any) -> None: @@ -47,6 +157,75 @@ def test_locate(name: str, expected: Any) -> None: assert _locate(name) == expected -@mark.parametrize("path,expected_type", [("tests.instantiate.AClass", AClass)]) -def test_get_class(path: str, expected_type: type) -> None: - assert get_class(path) == expected_type +@mark.parametrize( + "name", + [ + param(".", id="invalid-path-period"), + param("..", id="invalid-path-period2"), + param(".mod", id="invalid-path-relative"), + param("..mod", id="invalid-path-relative2"), + param("mod.", id="invalid-path-trailing-dot"), + param("mod..another", id="invalid-path-two-dots"), + ], +) +def test_locate_relative_import_fails(name: str) -> None: + with raises( + ValueError, + match=r"Error loading '.*': invalid dotstring\." + + re.escape("\nRelative imports are not supported."), + ): + _locate(name) + + +@mark.parametrize( + "path,expected", + [ + param("tests.instantiate.AClass", AClass, id="class"), + param("builtins.print", print, id="callable"), + param( + "datetime", + raises( + ValueError, + match="Located non-callable of type 'module' while loading 'datetime'", + ), + id="module-error", + ), + ], +) +def test_get_method(path: str, expected: Any) -> None: + if isinstance(expected, RaisesContext): + with expected: + get_method(path) + else: + assert get_method(path) == expected + + +@mark.parametrize( + "path,expected", + [ + param("tests.instantiate.AClass", AClass, id="class"), + param( + "builtins.print", + raises( + ValueError, + match="Located non-class of type 'builtin_function_or_method'" + + " while loading 'builtins.print'", + ), + id="callable-error", + ), + param( + "datetime", + raises( + ValueError, + match="Located non-class of type 'module' while loading 'datetime'", + ), + id="module-error", + ), + ], +) +def test_get_class(path: str, expected: Any) -> None: + if isinstance(expected, RaisesContext): + with expected: + get_class(path) + else: + assert get_class(path) == expected diff --git a/tests/instantiate/test_instantiate.py b/tests/instantiate/test_instantiate.py index 90150b81da8..ef120683569 100644 --- a/tests/instantiate/test_instantiate.py +++ b/tests/instantiate/test_instantiate.py @@ -4,6 +4,7 @@ import re from dataclasses import dataclass from functools import partial +from textwrap import dedent from typing import Any, Dict, List, Optional, Tuple from omegaconf import MISSING, DictConfig, ListConfig, OmegaConf @@ -11,6 +12,7 @@ import hydra from hydra.errors import InstantiationException +from hydra.test_utils.test_utils import assert_multiline_regex_search from hydra.types import ConvertMode, TargetConf from tests.instantiate import ( AClass, @@ -110,6 +112,26 @@ def config(request: Any, src: Any) -> Any: partial(AClass, a=10, b=20, c=30), id="class+partial", ), + param( + [ + { + "_target_": "tests.instantiate.AClass", + "_partial_": True, + "a": 10, + "b": 20, + "c": 30, + }, + { + "_target_": "tests.instantiate.BClass", + "a": 50, + "b": 60, + "c": 70, + }, + ], + {}, + [partial(AClass, a=10, b=20, c=30), BClass(a=50, b=60, c=70)], + id="list_of_partial_class", + ), param( {"_target_": "tests.instantiate.AClass", "b": 20, "c": 30}, {"a": 10, "d": 40}, @@ -314,6 +336,28 @@ def config(request: Any, src: Any) -> Any: KeywordsInParamsClass(target="foo", partial="bar"), id="keywords_in_params", ), + param([], {}, [], id="list_as_toplevel0"), + param( + [ + { + "_target_": "tests.instantiate.AClass", + "a": 10, + "b": 20, + "c": 30, + "d": 40, + }, + { + "_target_": "tests.instantiate.BClass", + "a": 50, + "b": 60, + "c": 70, + "d": 80, + }, + ], + {}, + [AClass(10, 20, 30, 40), BClass(50, 60, 70, 80)], + id="list_as_toplevel2", + ), ], ) def test_class_instantiate( @@ -325,10 +369,7 @@ def test_class_instantiate( ) -> Any: passthrough["_recursive_"] = recursive obj = instantiate_func(config, **passthrough) - if isinstance(expected, partial): - assert partial_equal(obj, expected) - else: - assert obj == expected + assert partial_equal(obj, expected) def test_none_cases( @@ -458,7 +499,10 @@ def test_class_instantiate_omegaconf_node(instantiate_func: Any, config: Any) -> @mark.parametrize("src", [{"_target_": "tests.instantiate.Adam"}]) def test_instantiate_adam(instantiate_func: Any, config: Any) -> None: - with raises(TypeError): + with raises( + InstantiationException, + match=r"Error in call to target 'tests\.instantiate\.Adam':\nTypeError\(.*\)", + ): # can't instantiate without passing params instantiate_func(config) @@ -499,7 +543,10 @@ def gen() -> Any: def test_instantiate_adam_conf( instantiate_func: Any, is_partial: bool, expected_params: Any ) -> None: - with raises(TypeError): + with raises( + InstantiationException, + match=r"Error in call to target 'tests\.instantiate\.Adam':\nTypeError\(.*\)", + ): # can't instantiate without passing params instantiate_func(AdamConf()) @@ -541,24 +588,82 @@ def test_targetconf_deprecated() -> None: def test_instantiate_bad_adam_conf(instantiate_func: Any, recwarn: Any) -> None: - msg = ( - "Missing value for BadAdamConf._target_. Check that it's properly annotated and overridden." - "\nA common problem is forgetting to annotate _target_ as a string : '_target_: str = ...'" + msg = re.escape( + dedent( + """\ + Config has missing value for key `_target_`, cannot instantiate. + Config type: BadAdamConf + Check that the `_target_` key in your dataclass is properly annotated and overridden. + A common problem is forgetting to annotate _target_ as a string : '_target_: str = ...'""" + ) ) with raises( InstantiationException, - match=re.escape(msg), + match=msg, ): instantiate_func(BadAdamConf()) def test_instantiate_with_missing_module(instantiate_func: Any) -> None: + _target_ = "tests.instantiate.ClassWithMissingModule" with raises( - ModuleNotFoundError, match=re.escape("No module named 'some_missing_module'") + InstantiationException, + match=dedent( + rf""" + Error in call to target '{re.escape(_target_)}': + ModuleNotFoundError\("No module named 'some_missing_module'",?\)""" + ).strip(), ): # can't instantiate when importing a missing module - instantiate_func({"_target_": "tests.instantiate.ClassWithMissingModule"}) + instantiate_func({"_target_": _target_}) + + +def test_instantiate_target_raising_exception_taking_no_arguments( + instantiate_func: Any, +) -> None: + _target_ = "tests.instantiate.raise_exception_taking_no_argument" + with raises( + InstantiationException, + match=( + dedent( + rf""" + Error in call to target '{re.escape(_target_)}': + ExceptionTakingNoArgument\('Err message',?\)""" + ).strip() + ), + ): + instantiate_func({}, _target_=_target_) + + +def test_instantiate_target_raising_exception_taking_no_arguments_nested( + instantiate_func: Any, +) -> None: + _target_ = "tests.instantiate.raise_exception_taking_no_argument" + with raises( + InstantiationException, + match=( + dedent( + rf""" + Error in call to target '{re.escape(_target_)}': + ExceptionTakingNoArgument\('Err message',?\) + full_key: foo + """ + ).strip() + ), + ): + instantiate_func({"foo": {"_target_": _target_}}) + + +def test_toplevel_list_partial_not_allowed(instantiate_func: Any) -> None: + config = [{"_target_": "tests.instantiate.ClassA", "a": 10, "b": 20, "c": 30}] + with raises( + InstantiationException, + match=re.escape( + "The _partial_ keyword is not compatible with top-level list instantiation" + ), + ): + instantiate_func(config, _partial_=True) @mark.parametrize("is_partial", [True, False]) @@ -1220,26 +1325,65 @@ def test_instantiate_from_class_in_dict( @mark.parametrize( - "config, passthrough, expected", + "config, passthrough, err_msg", [ param( OmegaConf.create({"_target_": AClass}), {}, - AClass(10, 20, 30, 40), - id="class_in_config_dict", + re.escape( + "Expected a callable target, got" + + " '{'a': '???', 'b': '???', 'c': '???', 'd': 'default_value'}' of type 'DictConfig'" + ), + id="instantiate-from-dataclass-in-dict-fails", + ), + param( + OmegaConf.create({"foo": {"_target_": AClass}}), + {}, + re.escape( + "Expected a callable target, got" + + " '{'a': '???', 'b': '???', 'c': '???', 'd': 'default_value'}' of type 'DictConfig'" + + "\nfull_key: foo" + ), + id="instantiate-from-dataclass-in-dict-fails-nested", ), ], ) def test_instantiate_from_dataclass_in_dict_fails( - instantiate_func: Any, config: Any, passthrough: Any, expected: Any + instantiate_func: Any, config: Any, passthrough: Any, err_msg: str ) -> None: - # not the best error, but it will get the user to check their input config. - msg = "Unsupported target type: DictConfig. value: {'a': '???', 'b': '???', 'c': '???', 'd': 'default_value'}" with raises( InstantiationException, - match=re.escape(msg), + match=err_msg, ): - assert instantiate_func(config, **passthrough) == expected + instantiate_func(config, **passthrough) + + +def test_cannot_locate_target(instantiate_func: Any) -> None: + cfg = OmegaConf.create({"foo": {"_target_": "not_found"}}) + with raises( + InstantiationException, + match=re.escape( + dedent( + """\ + Error locating target 'not_found', see chained exception above. + full_key: foo""" + ) + ), + ) as exc_info: + instantiate_func(cfg) + err = exc_info.value + assert hasattr(err, "__cause__") + chained = err.__cause__ + assert isinstance(chained, ImportError) + assert_multiline_regex_search( + dedent( + """\ + Error loading 'not_found': + ModuleNotFoundError\\("No module named 'not_found'",?\\) + Are you sure that module 'not_found' is installed\\?""" + ), + chained.args[0], + ) @mark.parametrize( diff --git a/tests/instantiate/test_positional.py b/tests/instantiate/test_positional.py index e75112ec009..c0525a2bf59 100644 --- a/tests/instantiate/test_positional.py +++ b/tests/instantiate/test_positional.py @@ -1,8 +1,10 @@ # Copyright (c) Facebook, Inc. and its affiliates. All Rights Reserved +from textwrap import dedent from typing import Any -from pytest import mark, param +from pytest import mark, param, raises +from hydra.errors import InstantiationException from hydra.utils import instantiate from tests.instantiate import ArgsClass @@ -47,6 +49,43 @@ def test_instantiate_args_kwargs(cfg: Any, expected: Any) -> None: assert instantiate(cfg) == expected +@mark.parametrize( + "cfg, msg", + [ + param( + {"_target_": "tests.instantiate.ArgsClass", "_args_": {"foo": "bar"}}, + dedent( + """\ + Error in collecting args and kwargs for 'tests\\.instantiate\\.ArgsClass': + InstantiationException\\("Unsupported _args_ type: 'DictConfig'\\. value: '{'foo': 'bar'}'",?\\)""" + ), + id="unsupported-args-type", + ), + param( + { + "foo": { + "_target_": "tests.instantiate.ArgsClass", + "_args_": {"foo": "bar"}, + } + }, + dedent( + """\ + Error in collecting args and kwargs for 'tests\\.instantiate\\.ArgsClass': + InstantiationException\\("Unsupported _args_ type: 'DictConfig'\\. value: '{'foo': 'bar'}'",?\\) + full_key: foo""" + ), + id="unsupported-args-type-nested", + ), + ], +) +def test_instantiate_unsupported_args_type(cfg: Any, msg: str) -> None: + with raises( + InstantiationException, + match=msg, + ): + instantiate(cfg) + + @mark.parametrize( ("cfg", "expected"), [