From 9e49984dd65542801ea38982bf2b69b64cb9fc96 Mon Sep 17 00:00:00 2001 From: Jasha <8935917+Jasha10@users.noreply.github.com> Date: Thu, 26 May 2022 04:05:26 -0500 Subject: [PATCH] list/tuple subclass: behavior same as 2.1_branch --- omegaconf/omegaconf.py | 2 +- tests/__init__.py | 26 ++++++++++++++++++- tests/test_create.py | 57 ++++++++++++++++++++++++++++++++++++++++-- tests/test_utils.py | 53 +++++++++++++++++++++++++++++++++++++++ 4 files changed, 134 insertions(+), 4 deletions(-) diff --git a/omegaconf/omegaconf.py b/omegaconf/omegaconf.py index 3ec0202a0..7bdabaf59 100644 --- a/omegaconf/omegaconf.py +++ b/omegaconf/omegaconf.py @@ -993,7 +993,7 @@ def _node_wrap( element_type=element_type, ) elif (is_list_annotation(ref_type) or is_tuple_annotation(ref_type)) or ( - is_primitive_list(value) and ref_type is Any + type(value) in (list, tuple) and ref_type is Any ): element_type = get_list_element_type(ref_type) node = ListConfig( diff --git a/tests/__init__.py b/tests/__init__.py index 79522d110..489d94708 100644 --- a/tests/__init__.py +++ b/tests/__init__.py @@ -1,7 +1,7 @@ import re from dataclasses import dataclass, field from enum import Enum -from typing import Any, Dict, List, Optional, Tuple, Union +from typing import Any, Dict, List, NamedTuple, Optional, Tuple, Union import attr from pytest import warns @@ -226,6 +226,11 @@ class SubscriptedListOpt: list_opt: List[Optional[int]] = field(default_factory=lambda: [1, 2, None]) +@dataclass +class ListOfAny: + list: List[Any] + + @dataclass class UntypedDict: dict: Dict = field(default_factory=lambda: {"foo": "var"}) # type: ignore @@ -250,6 +255,11 @@ class SubscriptedDictOpt: ) +@dataclass +class DictOfAny: + dict: Dict[Any, Any] + + @dataclass class InterpolationList: list: List[float] = II("optimization.lr") @@ -265,6 +275,20 @@ class Str2Int(Dict[str, int]): pass +class DictSubclass(Dict[Any, Any]): + pass + + +class ListSubclass(List[Any]): + pass + + +class Shape(NamedTuple): + channels: int + height: int + width: int + + @dataclass class OptTuple: x: Optional[Tuple[int, ...]] = None diff --git a/tests/test_create.py b/tests/test_create.py index d89e59018..ada960e8a 100644 --- a/tests/test_create.py +++ b/tests/test_create.py @@ -2,6 +2,7 @@ import platform import re import sys +from collections.abc import Sequence from pathlib import Path from textwrap import dedent from typing import Any, Dict, List, Optional @@ -10,8 +11,18 @@ from pytest import mark, param, raises from omegaconf import DictConfig, ListConfig, OmegaConf -from omegaconf.errors import UnsupportedValueType -from tests import ConcretePlugin, IllegalType, NonCopyableIllegalType, Plugin +from omegaconf.errors import UnsupportedValueType, ValidationError +from tests import ( + ConcretePlugin, + DictOfAny, + DictSubclass, + IllegalType, + ListOfAny, + ListSubclass, + NonCopyableIllegalType, + Plugin, + Shape, +) @mark.parametrize( @@ -112,6 +123,48 @@ def test_create_allow_objects_non_copyable(input_: Any) -> None: assert cfg == input_ +@mark.parametrize( + "input_", + [ + param(Shape(10, 2, 3), id="shape"), + param(ListSubclass((1, 2, 3)), id="list_subclass"), + param(DictSubclass({"key": "value"}), id="dict_subclass"), + ], +) +class TestCreationWithCustomClass: + def test_top_level(self, input_: Any) -> None: + if isinstance(input_, Sequence): + cfg = OmegaConf.create(input_) # type: ignore + assert isinstance(cfg, ListConfig) + else: + with raises(ValidationError): + OmegaConf.create(input_) + + def test_nested(self, input_: Any) -> None: + with raises(UnsupportedValueType): + OmegaConf.create({"foo": input_}) + + def test_nested_allow_objects(self, input_: Any) -> None: + cfg = OmegaConf.create({"foo": input_}, flags={"allow_objects": True}) + assert isinstance(cfg.foo, type(input_)) + + def test_structured_conf(self, input_: Any) -> None: + if isinstance(input_, Sequence): + cfg = OmegaConf.structured(ListOfAny(input_)) # type: ignore + assert isinstance(cfg.list, ListConfig) + else: + cfg = OmegaConf.structured(DictOfAny(input_)) + assert isinstance(cfg.dict, DictConfig) + + def test_direct_creation_of_listconfig_or_dictconfig(self, input_: Any) -> None: + if isinstance(input_, Sequence): + cfg = ListConfig(input_) # type: ignore + assert isinstance(cfg, ListConfig) + else: + cfg = DictConfig(input_) # type: ignore + assert isinstance(cfg, DictConfig) + + @mark.parametrize( "input_", [ diff --git a/tests/test_utils.py b/tests/test_utils.py index f859c16e3..62ca14b52 100644 --- a/tests/test_utils.py +++ b/tests/test_utils.py @@ -20,6 +20,8 @@ get_list_element_type, is_dict_annotation, is_list_annotation, + is_primitive_dict, + is_primitive_list, is_supported_union_annotation, is_tuple_annotation, is_union_annotation, @@ -42,8 +44,11 @@ Color, ConcretePlugin, Dataframe, + DictSubclass, IllegalType, + ListSubclass, Plugin, + Shape, Str2Int, UnionAnnotations, User, @@ -662,6 +667,54 @@ def test_type_str_nonetype(type_: Any, expected: str) -> None: assert _utils.type_str(type_) == expected +@mark.parametrize( + "obj, expected", + [ + param([], True, id="list"), + param([1], True, id="list1"), + param((), True, id="tuple"), + param((1,), True, id="tuple1"), + param({}, False, id="dict"), + param(ListSubclass(), True, id="list_subclass"), + param(Shape(10, 2, 3), True, id="namedtuple"), + ], +) +def test_is_primitive_list(obj: Any, expected: bool) -> None: + assert is_primitive_list(obj) == expected + + +@mark.parametrize( + "obj, expected", + [ + param({}, True, id="dict"), + param({1: 2}, True, id="dict1"), + param([], False, id="list"), + param((), False, id="tuple"), + ], +) +def test_is_primitive_dict(obj: Any, expected: bool) -> None: + assert is_primitive_dict(obj) == expected + + +@mark.parametrize( + "obj", + [ + param(DictConfig({}), id="dictconfig"), + param(ListConfig([]), id="listconfig"), + param(DictSubclass(), id="dict_subclass"), + param(Str2Int(), id="dict_subclass_dataclass"), + param(User, id="user"), + param(User("bond", 7), id="user"), + ], +) +class TestIsPrimitiveContainerNegative: + def test_is_primitive_list(self, obj: Any) -> None: + assert not is_primitive_list(obj) + + def test_is_primitive_dict(self, obj: Any) -> None: + assert not is_primitive_dict(obj) + + @mark.parametrize( "type_, expected", [