Skip to content

Commit ec2cc83

Browse files
authored
Instantiation mode "partial" to "callable". Return the _target_ component as-is when in _mode_="callable" and no kwargs are specified (#7413)
### Description A `_target_` component with `_mode_="partial"` will still be wrapped in `functools.partial` even when no kwargs are passed: `functool.partial(component)`. In such cases, the component can just be returned as-is. If you agree with this, I will add tests for it. Thank you! ### Types of changes <!--- Put an `x` in all the boxes that apply, and remove the not applicable items --> - [x] Non-breaking change (fix or new feature that would not break existing functionality). - [ ] Breaking change (fix or new feature that would cause existing functionality to change). - [ ] New tests added to cover the changes. - [ ] Integration tests passed locally by running `./runtests.sh -f -u --net --coverage`. - [ ] Quick tests passed locally by running `./runtests.sh --quick --unittests --disttests`. - [ ] In-line docstrings updated. - [ ] Documentation updated, tested `make html` command in the `docs/` folder. --------- Signed-off-by: Ibrahim Hadzic <[email protected]>
1 parent 718e4be commit ec2cc83

File tree

6 files changed

+16
-14
lines changed

6 files changed

+16
-14
lines changed

docs/source/config_syntax.md

Lines changed: 4 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -168,9 +168,10 @@ _Description:_ `_requires_`, `_disabled_`, `_desc_`, and `_mode_` are optional k
168168
- `_mode_` specifies the operating mode when the component is instantiated or the callable is called.
169169
it currently supports the following values:
170170
- `"default"` (default) -- return the return value of ``_target_(**kwargs)``
171-
- `"partial"` -- return a partial function of ``functools.partial(_target_, **kwargs)`` (this is often
172-
useful when some portion of the full set of arguments are supplied to the ``_target_``, and the user wants to
173-
call it with additional arguments later).
171+
- `"callable"` -- return a callable, either as ``_target_`` itself or, if ``kwargs`` are provided, as a
172+
partial function of ``functools.partial(_target_, **kwargs)``. Useful for defining a class or function
173+
that will be instantied or called later. User can pre-define some arguments to the ``_target_`` and call
174+
it with additional arguments later.
174175
- `"debug"` -- execute with debug prompt and return the return value of ``pdb.runcall(_target_, **kwargs)``,
175176
see also [`pdb.runcall`](https://docs.python.org/3/library/pdb.html#pdb.runcall).
176177

monai/bundle/config_item.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -181,7 +181,7 @@ class ConfigComponent(ConfigItem, Instantiable):
181181
- ``"_mode_"`` (optional): operating mode for invoking the callable ``component`` defined by ``"_target_"``:
182182
183183
- ``"default"``: returns ``component(**kwargs)``
184-
- ``"partial"``: returns ``functools.partial(component, **kwargs)``
184+
- ``"callable"``: returns ``component`` or, if ``kwargs`` are provided, ``functools.partial(component, **kwargs)``
185185
- ``"debug"``: returns ``pdb.runcall(component, **kwargs)``
186186
187187
Other fields in the config content are input arguments to the python module.

monai/utils/enums.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -411,7 +411,7 @@ class CompInitMode(StrEnum):
411411
"""
412412

413413
DEFAULT = "default"
414-
PARTIAL = "partial"
414+
CALLABLE = "callable"
415415
DEBUG = "debug"
416416

417417

monai/utils/module.py

Lines changed: 7 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -231,11 +231,14 @@ def instantiate(__path: str, __mode: str, **kwargs: Any) -> Any:
231231
232232
Args:
233233
__path: if a string is provided, it's interpreted as the full path of the target class or function component.
234-
If a callable is provided, ``__path(**kwargs)`` or ``functools.partial(__path, **kwargs)`` will be returned.
234+
If a callable is provided, ``__path(**kwargs)`` will be invoked and returned for ``__mode="default"``.
235+
For ``__mode="callable"``, the callable will be returned as ``__path`` or, if ``kwargs`` are provided,
236+
as ``functools.partial(__path, **kwargs)`` for future invoking.
237+
235238
__mode: the operating mode for invoking the (callable) ``component`` represented by ``__path``:
236239
237240
- ``"default"``: returns ``component(**kwargs)``
238-
- ``"partial"``: returns ``functools.partial(component, **kwargs)``
241+
- ``"callable"``: returns ``component`` or, if ``kwargs`` are provided, ``functools.partial(component, **kwargs)``
239242
- ``"debug"``: returns ``pdb.runcall(component, **kwargs)``
240243
241244
kwargs: keyword arguments to the callable represented by ``__path``.
@@ -259,8 +262,8 @@ def instantiate(__path: str, __mode: str, **kwargs: Any) -> Any:
259262
return component
260263
if m == CompInitMode.DEFAULT:
261264
return component(**kwargs)
262-
if m == CompInitMode.PARTIAL:
263-
return partial(component, **kwargs)
265+
if m == CompInitMode.CALLABLE:
266+
return partial(component, **kwargs) if kwargs else component
264267
if m == CompInitMode.DEBUG:
265268
warnings.warn(
266269
f"\n\npdb: instantiating component={component}, mode={m}\n"

tests/test_config_item.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -37,7 +37,7 @@
3737
TEST_CASE_5 = [{"_target_": "LoadImaged", "_disabled_": "true", "keys": ["image"]}, dict]
3838
# test non-monai modules and excludes
3939
TEST_CASE_6 = [{"_target_": "torch.optim.Adam", "params": torch.nn.PReLU().parameters(), "lr": 1e-4}, torch.optim.Adam]
40-
TEST_CASE_7 = [{"_target_": "decollate_batch", "detach": True, "pad": True, "_mode_": "partial"}, partial]
40+
TEST_CASE_7 = [{"_target_": "decollate_batch", "detach": True, "pad": True, "_mode_": "callable"}, partial]
4141
# test args contains "name" field
4242
TEST_CASE_8 = [
4343
{"_target_": "RandTorchVisiond", "keys": "image", "name": "ColorJitter", "brightness": 0.25},

tests/test_config_parser.py

Lines changed: 2 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -72,7 +72,6 @@ def case_pdb_inst(sarg=None):
7272

7373

7474
class TestClass:
75-
7675
@staticmethod
7776
def compute(a, b, func=lambda x, y: x + y):
7877
return func(a, b)
@@ -127,7 +126,6 @@ def __call__(self, a, b):
127126

128127

129128
class TestConfigParser(unittest.TestCase):
130-
131129
def test_config_content(self):
132130
test_config = {"preprocessing": [{"_target_": "LoadImage"}], "dataset": {"_target_": "Dataset"}}
133131
parser = ConfigParser(config=test_config)
@@ -183,7 +181,7 @@ def test_function(self, config):
183181
parser = ConfigParser(config=config, globals={"TestClass": TestClass})
184182
for id in config:
185183
if id in ("compute", "cls_compute"):
186-
parser[f"{id}#_mode_"] = "partial"
184+
parser[f"{id}#_mode_"] = "callable"
187185
func = parser.get_parsed_content(id=id)
188186
self.assertTrue(id in parser.ref_resolver.resolved_content)
189187
if id == "error_func":
@@ -279,7 +277,7 @@ def test_lambda_reference(self):
279277

280278
def test_non_str_target(self):
281279
configs = {
282-
"fwd": {"_target_": "[email protected]", "x": "$torch.rand(1, 3, 256, 256)", "_mode_": "partial"},
280+
"fwd": {"_target_": "[email protected]", "x": "$torch.rand(1, 3, 256, 256)", "_mode_": "callable"},
283281
"model": {"_target_": "monai.networks.nets.resnet.resnet18", "pretrained": False, "spatial_dims": 2},
284282
}
285283
self.assertTrue(callable(ConfigParser(config=configs).fwd))

0 commit comments

Comments
 (0)