Skip to content

Commit

Permalink
Cherry-picked instantiate improvements (#2120)
Browse files Browse the repository at this point in the history
  • Loading branch information
Jasha10 authored Apr 5, 2022
1 parent 5272a15 commit 0043434
Show file tree
Hide file tree
Showing 11 changed files with 571 additions and 108 deletions.
139 changes: 104 additions & 35 deletions hydra/_internal/instantiate/_instantiate2.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,9 +2,9 @@

import copy
import functools
import sys
from enum import Enum
from typing import Any, Callable, Sequence, Tuple, Union
from textwrap import dedent
from typing import Any, Callable, Dict, List, Sequence, Tuple, Union

from omegaconf import OmegaConf, SCMode
from omegaconf._utils import is_structured_config
Expand Down Expand Up @@ -32,7 +32,7 @@ def _is_target(x: Any) -> bool:
return False


def _extract_pos_args(*input_args: Any, **kwargs: Any) -> Tuple[Any, Any]:
def _extract_pos_args(input_args: Any, kwargs: Any) -> Tuple[Any, Any]:
config_args = kwargs.pop(_Keys.ARGS, ())
output_args = config_args

Expand All @@ -41,16 +41,16 @@ def _extract_pos_args(*input_args: Any, **kwargs: Any) -> Tuple[Any, Any]:
output_args = input_args
else:
raise InstantiationException(
f"Unsupported _args_ type: {type(config_args).__name__}. value: {config_args}"
f"Unsupported _args_ type: '{type(config_args).__name__}'. value: '{config_args}'"
)

return output_args, kwargs


def _call_target(_target_: Callable, _partial_: bool, *args, **kwargs) -> Any: # type: ignore
def _call_target(_target_: Callable, _partial_: bool, args, kwargs, full_key) -> Any: # type: ignore
"""Call target (type) with args and kwargs."""
try:
args, kwargs = _extract_pos_args(*args, **kwargs)
args, kwargs = _extract_pos_args(args, kwargs)
# detaching configs from parent.
# At this time, everything is resolved and the parent link can cause
# issues when serializing objects in some scenarios.
Expand All @@ -60,13 +60,35 @@ def _call_target(_target_: Callable, _partial_: bool, *args, **kwargs) -> Any:
for v in kwargs.values():
if OmegaConf.is_config(v):
v._set_parent(None)
if _partial_:
return functools.partial(_target_, *args, **kwargs)
return _target_(*args, **kwargs)
except Exception as e:
raise type(e)(
f"Error instantiating '{_convert_target_to_string(_target_)}' : {e}"
).with_traceback(sys.exc_info()[2])
msg = (
f"Error in collecting args and kwargs for '{_convert_target_to_string(_target_)}':"
+ f"\n{repr(e)}"
)
if full_key:
msg += f"\nfull_key: {full_key}"

raise InstantiationException(msg) from e

if _partial_:
try:
return functools.partial(_target_, *args, **kwargs)
except Exception as e:
msg = (
f"Error in creating partial({_convert_target_to_string(_target_)}, ...) object:"
+ f"\n{repr(e)}"
)
if full_key:
msg += f"\nfull_key: {full_key}"
raise InstantiationException(msg) from e
else:
try:
return _target_(*args, **kwargs)
except Exception as e:
msg = f"Error in call to target '{_convert_target_to_string(_target_)}':\n{repr(e)}"
if full_key:
msg += f"\nfull_key: {full_key}"
raise InstantiationException(msg) from e


def _convert_target_to_string(t: Any) -> Any:
Expand All @@ -78,40 +100,45 @@ def _convert_target_to_string(t: Any) -> Any:
return t


def _prepare_input_dict(d: Any) -> Any:
def _prepare_input_dict_or_list(d: Union[Dict[Any, Any], List[Any]]) -> Any:
res: Any
if isinstance(d, dict):
res = {}
for k, v in d.items():
if k == "_target_":
v = _convert_target_to_string(d["_target_"])
elif isinstance(v, (dict, list)):
v = _prepare_input_dict(v)
v = _prepare_input_dict_or_list(v)
res[k] = v
elif isinstance(d, list):
res = []
for v in d:
if isinstance(v, (list, dict)):
v = _prepare_input_dict(v)
v = _prepare_input_dict_or_list(v)
res.append(v)
else:
assert False
return res


def _resolve_target(
target: Union[str, type, Callable[..., Any]]
target: Union[str, type, Callable[..., Any]], full_key: str
) -> Union[type, Callable[..., Any]]:
"""Resolve target string, type or callable into type or callable."""
if isinstance(target, str):
return _locate(target)
if isinstance(target, type):
return target
if callable(target):
return target
raise InstantiationException(
f"Unsupported target type: {type(target).__name__}. value: {target}"
)
try:
target = _locate(target)
except Exception as e:
msg = f"Error locating target '{target}', see chained exception above."
if full_key:
msg += f"\nfull_key: {full_key}"
raise InstantiationException(msg) from e
if not callable(target):
msg = f"Expected a callable target, got '{target}' of type '{type(target).__name__}'"
if full_key:
msg += f"\nfull_key: {full_key}"
raise InstantiationException(msg)
return target


def instantiate(config: Any, *args: Any, **kwargs: Any) -> Any:
Expand Down Expand Up @@ -151,17 +178,23 @@ def instantiate(config: Any, *args: Any, **kwargs: Any) -> Any:
if isinstance(config, TargetConf) and config._target_ == "???":
# Specific check to give a good warning about failure to annotate _target_ as a string.
raise InstantiationException(
f"Missing value for {type(config).__name__}._target_. Check that it's properly annotated and overridden."
f"\nA common problem is forgetting to annotate _target_ as a string : '_target_: str = ...'"
dedent(
f"""\
Config has missing value for key `_target_`, cannot instantiate.
Config type: {type(config).__name__}
Check that the `_target_` key in your dataclass is properly annotated and overridden.
A common problem is forgetting to annotate _target_ as a string : '_target_: str = ...'"""
)
)
# TODO: print full key

if isinstance(config, dict):
config = _prepare_input_dict(config)
if isinstance(config, (dict, list)):
config = _prepare_input_dict_or_list(config)

kwargs = _prepare_input_dict(kwargs)
kwargs = _prepare_input_dict_or_list(kwargs)

# Structured Config always converted first to OmegaConf
if is_structured_config(config) or isinstance(config, dict):
if is_structured_config(config) or isinstance(config, (dict, list)):
config = OmegaConf.structured(config, flags={"allow_objects": True})

if OmegaConf.is_dict(config):
Expand All @@ -182,12 +215,40 @@ def instantiate(config: Any, *args: Any, **kwargs: Any) -> Any:
_convert_ = config.pop(_Keys.CONVERT, ConvertMode.NONE)
_partial_ = config.pop(_Keys.PARTIAL, False)

return instantiate_node(
config, *args, recursive=_recursive_, convert=_convert_, partial=_partial_
)
elif OmegaConf.is_list(config):
# Finalize config (convert targets to strings, merge with kwargs)
config_copy = copy.deepcopy(config)
config_copy._set_flag(
flags=["allow_objects", "struct", "readonly"], values=[True, False, False]
)
config_copy._set_parent(config._get_parent())
config = config_copy

OmegaConf.resolve(config)

_recursive_ = kwargs.pop(_Keys.RECURSIVE, True)
_convert_ = kwargs.pop(_Keys.CONVERT, ConvertMode.NONE)
_partial_ = kwargs.pop(_Keys.PARTIAL, False)

if _partial_:
raise InstantiationException(
"The _partial_ keyword is not compatible with top-level list instantiation"
)

return instantiate_node(
config, *args, recursive=_recursive_, convert=_convert_, partial=_partial_
)
else:
raise InstantiationException(
"Top level config has to be OmegaConf DictConfig, plain dict, or a Structured Config class or instance"
dedent(
f"""\
Cannot instantiate config of type {type(config).__name__}.
Top level config must be an OmegaConf DictConfig/ListConfig object,
a plain dict/list, or a Structured Config class or instance."""
)
)


Expand Down Expand Up @@ -224,11 +285,19 @@ def instantiate_node(
recursive = node[_Keys.RECURSIVE] if _Keys.RECURSIVE in node else recursive
partial = node[_Keys.PARTIAL] if _Keys.PARTIAL in node else partial

full_key = node._get_full_key(None)

if not isinstance(recursive, bool):
raise TypeError(f"_recursive_ flag must be a bool, got {type(recursive)}")
msg = f"Instantiation: _recursive_ flag must be a bool, got {type(recursive)}"
if full_key:
msg += f"\nfull_key: {full_key}"
raise TypeError(msg)

if not isinstance(partial, bool):
raise TypeError(f"_partial_ flag must be a bool, got {type( partial )}")
msg = f"Instantiation: _partial_ flag must be a bool, got {type( partial )}"
if node and full_key:
msg += f"\nfull_key: {full_key}"
raise TypeError(msg)

# If OmegaConf list, create new list of instances if recursive
if OmegaConf.is_list(node):
Expand All @@ -249,7 +318,7 @@ def instantiate_node(
elif OmegaConf.is_dict(node):
exclude_keys = set({"_target_", "_convert_", "_recursive_", "_partial_"})
if _is_target(node):
_target_ = _resolve_target(node.get(_Keys.TARGET))
_target_ = _resolve_target(node.get(_Keys.TARGET), full_key)
kwargs = {}
for key, value in node.items():
if key not in exclude_keys:
Expand All @@ -259,7 +328,7 @@ def instantiate_node(
)
kwargs[key] = _convert_node(value, convert)

return _call_target(_target_, partial, *args, **kwargs)
return _call_target(_target_, partial, args, kwargs, full_key)
else:
# If ALL or PARTIAL non structured, instantiate in dict and resolve interpolations eagerly.
if convert == ConvertMode.ALL or (
Expand Down
81 changes: 43 additions & 38 deletions hydra/_internal/utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -7,8 +7,8 @@
from dataclasses import dataclass
from os.path import dirname, join, normpath, realpath
from traceback import print_exc, print_exception
from types import FrameType, TracebackType
from typing import Any, Callable, List, Optional, Sequence, Tuple, Union
from types import FrameType, ModuleType, TracebackType
from typing import Any, List, Optional, Sequence, Tuple

from omegaconf.errors import OmegaConfBaseException

Expand Down Expand Up @@ -551,52 +551,57 @@ def get_column_widths(matrix: List[List[str]]) -> List[int]:
return widths


def _locate(path: str) -> Union[type, Callable[..., Any]]:
def _locate(path: str) -> Any:
"""
Locate an object by name or dotted path, importing as necessary.
This is similar to the pydoc function `locate`, except that it checks for
the module from the given path from back to front.
"""
if path == "":
raise ImportError("Empty path")
import builtins
from importlib import import_module

parts = [part for part in path.split(".") if part]
module = None
for n in reversed(range(len(parts))):
parts = [part for part in path.split(".")]
for part in parts:
if not len(part):
raise ValueError(
f"Error loading '{path}': invalid dotstring."
+ "\nRelative imports are not supported."
)
assert len(parts) > 0
part0 = parts[0]
try:
obj = import_module(part0)
except Exception as exc_import:
raise ImportError(
f"Error loading '{path}':\n{repr(exc_import)}"
+ f"\nAre you sure that module '{part0}' is installed?"
) from exc_import
for m in range(1, len(parts)):
part = parts[m]
try:
mod = ".".join(parts[:n])
module = import_module(mod)
except Exception as e:
if n == 0:
raise ImportError(f"Error loading module '{path}'") from e
continue
if module:
break
if module:
obj = module
else:
obj = builtins
for part in parts[n:]:
mod = mod + "." + part
if not hasattr(obj, part):
try:
import_module(mod)
except Exception as e:
raise ImportError(
f"Encountered error: `{e}` when loading module '{path}'"
) from e
obj = getattr(obj, part)
if isinstance(obj, type):
obj_type: type = obj
return obj_type
elif callable(obj):
obj_callable: Callable[..., Any] = obj
return obj_callable
else:
# dummy case
raise ValueError(f"Invalid type ({type(obj)}) found for {path}")
obj = getattr(obj, part)
except AttributeError as exc_attr:
parent_dotpath = ".".join(parts[:m])
if isinstance(obj, ModuleType):
mod = ".".join(parts[: m + 1])
try:
obj = import_module(mod)
continue
except ModuleNotFoundError as exc_import:
raise ImportError(
f"Error loading '{path}':\n{repr(exc_import)}"
+ f"\nAre you sure that '{part}' is importable from module '{parent_dotpath}'?"
) from exc_import
except Exception as exc_import:
raise ImportError(
f"Error loading '{path}':\n{repr(exc_import)}"
) from exc_import
raise ImportError(
f"Error loading '{path}':\n{repr(exc_attr)}"
+ f"\nAre you sure that '{part}' is an attribute of '{parent_dotpath}'?"
) from exc_attr
return obj


def _get_cls_name(config: Any, pop: bool = True) -> str:
Expand Down
17 changes: 12 additions & 5 deletions hydra/utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -22,18 +22,25 @@ def get_class(path: str) -> type:
try:
cls = _locate(path)
if not isinstance(cls, type):
raise ValueError(f"Located non-class in {path} : {type(cls).__name__}")
raise ValueError(
f"Located non-class of type '{type(cls).__name__}'"
+ f" while loading '{path}'"
)
return cls
except Exception as e:
log.error(f"Error initializing class at {path} : {e}")
log.error(f"Error initializing class at {path}: {e}")
raise e


def get_method(path: str) -> Callable[..., Any]:
try:
cl = _locate(path)
if not callable(cl):
raise ValueError(f"Non callable object located : {type(cl).__name__}")
obj = _locate(path)
if not callable(obj):
raise ValueError(
f"Located non-callable of type '{type(obj).__name__}'"
+ f" while loading '{path}'"
)
cl: Callable[..., Any] = obj
return cl
except Exception as e:
log.error(f"Error getting callable at {path} : {e}")
Expand Down
1 change: 1 addition & 0 deletions news/1950.feature
Original file line number Diff line number Diff line change
@@ -0,0 +1 @@
The `instantiate` API now accepts `ListConfig`/`list`-type config as top-level input.
1 change: 1 addition & 0 deletions news/2099.feature
Original file line number Diff line number Diff line change
@@ -0,0 +1 @@
Improve error messages raised in case of instantiation failure.
Loading

0 comments on commit 0043434

Please sign in to comment.