Skip to content

Commit

Permalink
list/tuple subclass: behavior same as 2.1_branch
Browse files Browse the repository at this point in the history
  • Loading branch information
Jasha10 committed May 27, 2022
1 parent 0af19e6 commit 9e49984
Show file tree
Hide file tree
Showing 4 changed files with 134 additions and 4 deletions.
2 changes: 1 addition & 1 deletion omegaconf/omegaconf.py
Original file line number Diff line number Diff line change
Expand Up @@ -993,7 +993,7 @@ def _node_wrap(
element_type=element_type,
)
elif (is_list_annotation(ref_type) or is_tuple_annotation(ref_type)) or (
is_primitive_list(value) and ref_type is Any
type(value) in (list, tuple) and ref_type is Any
):
element_type = get_list_element_type(ref_type)
node = ListConfig(
Expand Down
26 changes: 25 additions & 1 deletion tests/__init__.py
Original file line number Diff line number Diff line change
@@ -1,7 +1,7 @@
import re
from dataclasses import dataclass, field
from enum import Enum
from typing import Any, Dict, List, Optional, Tuple, Union
from typing import Any, Dict, List, NamedTuple, Optional, Tuple, Union

import attr
from pytest import warns
Expand Down Expand Up @@ -226,6 +226,11 @@ class SubscriptedListOpt:
list_opt: List[Optional[int]] = field(default_factory=lambda: [1, 2, None])


@dataclass
class ListOfAny:
list: List[Any]


@dataclass
class UntypedDict:
dict: Dict = field(default_factory=lambda: {"foo": "var"}) # type: ignore
Expand All @@ -250,6 +255,11 @@ class SubscriptedDictOpt:
)


@dataclass
class DictOfAny:
dict: Dict[Any, Any]


@dataclass
class InterpolationList:
list: List[float] = II("optimization.lr")
Expand All @@ -265,6 +275,20 @@ class Str2Int(Dict[str, int]):
pass


class DictSubclass(Dict[Any, Any]):
pass


class ListSubclass(List[Any]):
pass


class Shape(NamedTuple):
channels: int
height: int
width: int


@dataclass
class OptTuple:
x: Optional[Tuple[int, ...]] = None
Expand Down
57 changes: 55 additions & 2 deletions tests/test_create.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,6 +2,7 @@
import platform
import re
import sys
from collections.abc import Sequence
from pathlib import Path
from textwrap import dedent
from typing import Any, Dict, List, Optional
Expand All @@ -10,8 +11,18 @@
from pytest import mark, param, raises

from omegaconf import DictConfig, ListConfig, OmegaConf
from omegaconf.errors import UnsupportedValueType
from tests import ConcretePlugin, IllegalType, NonCopyableIllegalType, Plugin
from omegaconf.errors import UnsupportedValueType, ValidationError
from tests import (
ConcretePlugin,
DictOfAny,
DictSubclass,
IllegalType,
ListOfAny,
ListSubclass,
NonCopyableIllegalType,
Plugin,
Shape,
)


@mark.parametrize(
Expand Down Expand Up @@ -112,6 +123,48 @@ def test_create_allow_objects_non_copyable(input_: Any) -> None:
assert cfg == input_


@mark.parametrize(
"input_",
[
param(Shape(10, 2, 3), id="shape"),
param(ListSubclass((1, 2, 3)), id="list_subclass"),
param(DictSubclass({"key": "value"}), id="dict_subclass"),
],
)
class TestCreationWithCustomClass:
def test_top_level(self, input_: Any) -> None:
if isinstance(input_, Sequence):
cfg = OmegaConf.create(input_) # type: ignore
assert isinstance(cfg, ListConfig)
else:
with raises(ValidationError):
OmegaConf.create(input_)

def test_nested(self, input_: Any) -> None:
with raises(UnsupportedValueType):
OmegaConf.create({"foo": input_})

def test_nested_allow_objects(self, input_: Any) -> None:
cfg = OmegaConf.create({"foo": input_}, flags={"allow_objects": True})
assert isinstance(cfg.foo, type(input_))

def test_structured_conf(self, input_: Any) -> None:
if isinstance(input_, Sequence):
cfg = OmegaConf.structured(ListOfAny(input_)) # type: ignore
assert isinstance(cfg.list, ListConfig)
else:
cfg = OmegaConf.structured(DictOfAny(input_))
assert isinstance(cfg.dict, DictConfig)

def test_direct_creation_of_listconfig_or_dictconfig(self, input_: Any) -> None:
if isinstance(input_, Sequence):
cfg = ListConfig(input_) # type: ignore
assert isinstance(cfg, ListConfig)
else:
cfg = DictConfig(input_) # type: ignore
assert isinstance(cfg, DictConfig)


@mark.parametrize(
"input_",
[
Expand Down
53 changes: 53 additions & 0 deletions tests/test_utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -20,6 +20,8 @@
get_list_element_type,
is_dict_annotation,
is_list_annotation,
is_primitive_dict,
is_primitive_list,
is_supported_union_annotation,
is_tuple_annotation,
is_union_annotation,
Expand All @@ -42,8 +44,11 @@
Color,
ConcretePlugin,
Dataframe,
DictSubclass,
IllegalType,
ListSubclass,
Plugin,
Shape,
Str2Int,
UnionAnnotations,
User,
Expand Down Expand Up @@ -662,6 +667,54 @@ def test_type_str_nonetype(type_: Any, expected: str) -> None:
assert _utils.type_str(type_) == expected


@mark.parametrize(
"obj, expected",
[
param([], True, id="list"),
param([1], True, id="list1"),
param((), True, id="tuple"),
param((1,), True, id="tuple1"),
param({}, False, id="dict"),
param(ListSubclass(), True, id="list_subclass"),
param(Shape(10, 2, 3), True, id="namedtuple"),
],
)
def test_is_primitive_list(obj: Any, expected: bool) -> None:
assert is_primitive_list(obj) == expected


@mark.parametrize(
"obj, expected",
[
param({}, True, id="dict"),
param({1: 2}, True, id="dict1"),
param([], False, id="list"),
param((), False, id="tuple"),
],
)
def test_is_primitive_dict(obj: Any, expected: bool) -> None:
assert is_primitive_dict(obj) == expected


@mark.parametrize(
"obj",
[
param(DictConfig({}), id="dictconfig"),
param(ListConfig([]), id="listconfig"),
param(DictSubclass(), id="dict_subclass"),
param(Str2Int(), id="dict_subclass_dataclass"),
param(User, id="user"),
param(User("bond", 7), id="user"),
],
)
class TestIsPrimitiveContainerNegative:
def test_is_primitive_list(self, obj: Any) -> None:
assert not is_primitive_list(obj)

def test_is_primitive_dict(self, obj: Any) -> None:
assert not is_primitive_dict(obj)


@mark.parametrize(
"type_, expected",
[
Expand Down

0 comments on commit 9e49984

Please sign in to comment.