Skip to content

Commit

Permalink
Fix crash with "interpolation-like" strings from interpolations (#709)
Browse files Browse the repository at this point in the history
Fix crash with "interpolation-like" strings from interpolations

This commit introduces a new node type `InterpolationResultNode` that systematically wraps interpolation results that either (a) are not already nodes, or (b) need to be converted.

Fixes #666
  • Loading branch information
odelalleau authored May 12, 2021
1 parent d67f69d commit fe6b207
Show file tree
Hide file tree
Showing 5 changed files with 140 additions and 57 deletions.
46 changes: 2 additions & 44 deletions omegaconf/base.py
Original file line number Diff line number Diff line change
Expand Up @@ -15,7 +15,6 @@
_is_missing_value,
format_and_raise,
get_value_kind,
is_primitive_type,
split_key,
)
from .errors import (
Expand Down Expand Up @@ -504,7 +503,7 @@ def _validate_and_convert_interpolation_result(
resolved: Any,
throw_on_resolution_failure: bool,
) -> Optional["Node"]:
from .nodes import AnyNode, ValueNode
from .nodes import AnyNode, InterpolationResultNode, ValueNode

# If the output is not a Node already (e.g., because it is the output of a
# custom resolver), then we will need to wrap it within a Node.
Expand Down Expand Up @@ -533,52 +532,11 @@ def _validate_and_convert_interpolation_result(
resolved = conv_value

if must_wrap:
return self._wrap_interpolation_result(
parent=parent,
value=value,
key=key,
resolved=resolved,
throw_on_resolution_failure=throw_on_resolution_failure,
)
return InterpolationResultNode(value=resolved, key=key, parent=parent)
else:
assert isinstance(resolved, Node)
return resolved

def _wrap_interpolation_result(
self,
parent: Optional["Container"],
value: Node,
key: Any,
resolved: Any,
throw_on_resolution_failure: bool,
) -> Optional["Node"]:
from .basecontainer import BaseContainer
from .nodes import AnyNode
from .omegaconf import _node_wrap

assert parent is None or isinstance(parent, BaseContainer)

if is_primitive_type(type(resolved)):
# Primitive types get wrapped using `_node_wrap()`, ensuring value is
# validated and potentially converted.
wrapped = _node_wrap(
type_=value._metadata.ref_type,
parent=parent,
is_optional=value._metadata.optional,
value=resolved,
key=key,
ref_type=value._metadata.ref_type,
)
else:
# Other objects get wrapped into an `AnyNode` with `allow_objects` set
# to True.
wrapped = AnyNode(
value=resolved, key=key, parent=None, flags={"allow_objects": True}
)
wrapped._set_parent(parent)

return wrapped

def _validate_not_dereferencing_to_parent(self, node: Node, target: Node) -> None:
parent: Optional[Node] = node
while parent is not None:
Expand Down
43 changes: 42 additions & 1 deletion omegaconf/nodes.py
Original file line number Diff line number Diff line change
Expand Up @@ -24,7 +24,7 @@ def __init__(self, parent: Optional[Container], value: Any, metadata: Metadata):

super().__init__(parent=parent, metadata=metadata)
with read_write(self):
self._set_value(value)
self._set_value(value) # lgtm [py/init-calls-subclass]

def _value(self) -> Any:
return self._val
Expand Down Expand Up @@ -390,3 +390,44 @@ def __deepcopy__(self, memo: Dict[int, Any]) -> "EnumNode":
res = EnumNode(enum_type=self.enum_type)
self._deepcopy_impl(res, memo)
return res


class InterpolationResultNode(ValueNode):
"""
Special node type, used to wrap interpolation results.
"""

def __init__(
self,
value: Any,
key: Any = None,
parent: Optional[Container] = None,
flags: Optional[Dict[str, bool]] = None,
):
super().__init__(
parent=parent,
value=value,
metadata=Metadata(
ref_type=Any, object_type=None, key=key, optional=True, flags=flags
),
)
# In general we should not try to write into interpolation results.
if flags is None or "readonly" not in flags:
self._set_flag("readonly", True)

def _set_value(self, value: Any, flags: Optional[Dict[str, bool]] = None) -> None:
if self._get_flag("readonly"):
raise ReadonlyConfigError("Cannot set value of read-only config node")
self._val = self.validate_and_convert(value)

def _validate_and_convert_impl(self, value: Any) -> Any:
# Interpolation results may be anything.
return value

def __deepcopy__(self, memo: Dict[int, Any]) -> "InterpolationResultNode":
# Currently there should be no need to deep-copy such nodes.
raise NotImplementedError

def _is_interpolation(self) -> bool:
# The result of an interpolation cannot be itself an interpolation.
return False
10 changes: 5 additions & 5 deletions tests/interpolation/test_custom_resolvers.py
Original file line number Diff line number Diff line change
Expand Up @@ -5,7 +5,7 @@
from pytest import mark, param, raises, warns

from omegaconf import OmegaConf, Resolver
from omegaconf.nodes import AnyNode
from omegaconf.nodes import InterpolationResultNode
from tests.interpolation import dereference_node


Expand Down Expand Up @@ -355,8 +355,8 @@ def test_resolver_output_dict(restore_resolvers: Any, readonly: bool) -> None:
assert isinstance(c.x, dict)
assert c.x == some_dict
x_node = dereference_node(c, "x")
assert isinstance(x_node, AnyNode)
assert x_node._get_flag("allow_objects")
assert isinstance(x_node, InterpolationResultNode)
assert x_node._get_flag("readonly")


@mark.parametrize("readonly", [True, False])
Expand All @@ -378,8 +378,8 @@ def test_resolver_output_plain_dict_list(
assert c.x == data

x_node = dereference_node(c, "x")
assert isinstance(x_node, AnyNode)
assert x_node._get_flag("allow_objects")
assert isinstance(x_node, InterpolationResultNode)
assert x_node._get_flag("readonly")


def test_register_cached_resolver_with_keyword_unsupported() -> None:
Expand Down
53 changes: 47 additions & 6 deletions tests/interpolation/test_interpolation.py
Original file line number Diff line number Diff line change
Expand Up @@ -8,20 +8,24 @@
from omegaconf import (
II,
SI,
AnyNode,
Container,
DictConfig,
IntegerNode,
ListConfig,
Node,
OmegaConf,
StringNode,
ValidationError,
)
from omegaconf._utils import _ensure_container
from omegaconf.errors import InterpolationKeyError
from omegaconf.errors import InterpolationResolutionError
from omegaconf.errors import InterpolationResolutionError as IRE
from omegaconf.errors import InterpolationValidationError
from omegaconf.errors import InterpolationValidationError, ReadonlyConfigError
from omegaconf.nodes import InterpolationResultNode
from tests import MissingDict, MissingList, StructuredWithMissing, SubscriptedList, User
from tests.interpolation import dereference_node

# file deepcode ignore CopyPasteError:
# The above comment is a statement to stop DeepCode from raising a warning on
Expand Down Expand Up @@ -257,7 +261,7 @@ def test_none_value_in_quoted_string(restore_resolvers: Any) -> None:
User(name="Bond", age=SI("${cast:int,'7'}")),
"age",
7,
IntegerNode,
InterpolationResultNode,
id="expected_type",
),
param(
Expand All @@ -266,7 +270,7 @@ def test_none_value_in_quoted_string(restore_resolvers: Any) -> None:
User(name="Bond", age=SI("${cast:int,${drop_last:${drop_last:7xx}}}")),
"age",
7,
IntegerNode,
InterpolationResultNode,
id="intermediate_type_mismatch_ok",
),
param(
Expand All @@ -275,20 +279,20 @@ def test_none_value_in_quoted_string(restore_resolvers: Any) -> None:
User(name="Bond", age=SI("${cast:str,'7'}")),
"age",
7,
IntegerNode,
InterpolationResultNode,
id="convert_str_to_int",
),
param(
MissingList(list=SI("${oc.create:[a, b, c]}")),
"list",
["a", "b", "c"],
ListConfig(["a", "b", "c"]),
ListConfig,
id="list_str",
),
param(
MissingDict(dict=SI("${oc.create:{key1: val1, key2: val2}}")),
"dict",
{"key1": "val1", "key2": "val2"},
DictConfig({"key1": "val1", "key2": "val2"}),
DictConfig,
id="dict_str",
),
Expand All @@ -310,6 +314,7 @@ def drop_last(s: str) -> str:

val = cfg[key]
assert val == expected_value
assert type(val) is type(expected_value)

node = cfg._get_node(key)
assert isinstance(node, Node)
Expand Down Expand Up @@ -463,3 +468,39 @@ def test_circular_interpolation(cfg: Any, key: str, expected: Any) -> None:
OmegaConf.select(cfg, key)
else:
assert OmegaConf.select(cfg, key) == expected


@mark.parametrize(
"node_type",
[
param(lambda x: x, id="untyped"),
param(AnyNode, id="any"),
param(StringNode, id="str"),
],
)
@mark.parametrize(
("value", "expected"),
[
param(r"\${foo}", "${foo}", id="escaped_interpolation_1"),
param(r"\${foo", "${foo", id="escaped_interpolation_2"),
param(r"$${y1}", "${foo}", id="string_interpolation_1"),
param(r"$${y2}", "${foo", id="string_interpolation_2"),
# This passes to `oc.decode` the string with characters: '\${foo}' which
# is then resolved into: ${foo}
param(r"${oc.decode:'\'\\\${foo}\''}", "${foo}", id="resolver_1"),
param(r"${oc.decode:'\'\\\${foo\''}", "${foo", id="resolver_2"),
],
)
def test_interpolation_like_result_is_not_an_interpolation(
node_type: Any, value: str, expected: str
) -> None:
cfg = OmegaConf.create({"x": node_type(value), "y1": "{foo}", "y2": "{foo"})
assert cfg.x == expected

# Check that the resulting node is not considered to be an interpolation.
resolved_node = dereference_node(cfg, "x")
assert not resolved_node._is_interpolation()

# Check that the resulting node is read-only.
with raises(ReadonlyConfigError):
resolved_node._set_value("foo")
45 changes: 44 additions & 1 deletion tests/test_nodes.py
Original file line number Diff line number Diff line change
Expand Up @@ -24,6 +24,7 @@
UnsupportedValueType,
ValidationError,
)
from omegaconf.nodes import InterpolationResultNode
from tests import Color, IllegalType, User


Expand Down Expand Up @@ -499,14 +500,29 @@ def test_deepcopy(obj: Any) -> None:
True,
),
(EnumNode(enum_type=Enum1, value=Enum1.BAR), Enum1.BAR, True),
(InterpolationResultNode("foo"), "foo", True),
(InterpolationResultNode("${foo}"), "${foo}", True),
(InterpolationResultNode("${foo"), "${foo", True),
(InterpolationResultNode(None), None, True),
(InterpolationResultNode(1), 1, True),
(InterpolationResultNode(1.0), 1.0, True),
(InterpolationResultNode(True), True, True),
(InterpolationResultNode(Color.RED), Color.RED, True),
(InterpolationResultNode({"a": 0, "b": 1}), {"a": 0, "b": 1}, True),
(InterpolationResultNode([0, None, True]), [0, None, True], True),
(InterpolationResultNode("foo"), 100, False),
(InterpolationResultNode(100), "foo", False),
],
)
def test_eq(node: ValueNode, value: Any, expected: Any) -> None:
assert (node == value) == expected
assert (node != value) != expected
assert (value == node) == expected
assert (value != node) != expected
assert (node.__hash__() == value.__hash__()) == expected

# Check hash except for unhashable types (dict/list).
if not isinstance(value, (dict, list)):
assert (node.__hash__() == value.__hash__()) == expected


@mark.parametrize("value", [1, 3.14, True, None, Enum1.FOO])
Expand Down Expand Up @@ -616,10 +632,37 @@ def test_dereference_interpolation_to_missing() -> None:
functools.partial(EnumNode, enum_type=Color),
FloatNode,
IntegerNode,
InterpolationResultNode,
StringNode,
],
)
def test_set_flags_in_init(type_: Any, flags: Dict[str, bool]) -> None:
node = type_(value=None, flags=flags)
for f, v in flags.items():
assert node._get_flag(f) is v


@mark.parametrize(
"flags",
[
None,
{"flag": True},
{"flag": False},
{"readonly": True},
{"readonly": False},
{"flag1": True, "flag2": False, "readonly": False},
{"flag1": False, "flag2": True, "readonly": True},
],
)
def test_interpolation_result_readonly(flags: Any) -> None:
readonly = None if flags is None else flags.get("readonly")
expected = [] if flags is None else list(flags.items())
node = InterpolationResultNode("foo", flags=flags)

# Check that flags are set to their desired value.
for k, v in expected:
assert node._get_node_flag(k) is v

# If no value was provided for the "readonly" flag, it should be set.
if readonly is None:
assert node._get_node_flag("readonly")

0 comments on commit fe6b207

Please sign in to comment.