Skip to content

Commit

Permalink
🐛 Notify users about dependency condition (#150)
Browse files Browse the repository at this point in the history
  • Loading branch information
MiWeiss authored Feb 3, 2023
1 parent 5fe3b9a commit 1029d9d
Show file tree
Hide file tree
Showing 2 changed files with 21 additions and 1 deletion.
Original file line number Diff line number Diff line change
@@ -1,8 +1,13 @@
import inspect
import unittest
from unittest import TestCase

import tensorflow as tf

from uncertainty_wizard.internal_utils.tf_version_resolver import (
current_tf_version_is_older_than,
)


class TestExperimentalAPIAreAvailable(TestCase):
"""
Expand All @@ -21,6 +26,9 @@ def test_list_physical_devices(self):
self.assertTrue("device_type" in parameters)
self.assertEqual(1, len(parameters))

@unittest.skipIf(
not current_tf_version_is_older_than("2.10.0"), "Known to fail for tf >= 2.10.0"
)
def test_virtual_device_configuration(self):
self.assertTrue("VirtualDeviceConfiguration" in dir(tf.config.experimental))
parameters = inspect.signature(
Expand Down
14 changes: 13 additions & 1 deletion uncertainty_wizard/models/ensemble_utils/_lazy_contexts.py
Original file line number Diff line number Diff line change
Expand Up @@ -7,6 +7,9 @@

import tensorflow as tf

from uncertainty_wizard.internal_utils.tf_version_resolver import (
current_tf_version_is_older_than,
)
from uncertainty_wizard.models.ensemble_utils._save_config import SaveConfig

global number_of_tasks_in_this_process
Expand Down Expand Up @@ -35,7 +38,7 @@ def __init__(self, model_id: int, varargs: dict = None):
it will have to generate a context.
Later, to make it easier for custom child classes of EnsembleContextManager,
a (now still empty) varargs is also passed which may be populated with more information
in future versions of uncertainty_wizard.
in future s of uncertainty_wizard.
"""
self.ensemble_id = (model_id,)
self.varargs = varargs
Expand Down Expand Up @@ -225,6 +228,15 @@ class DeviceAllocatorContextManager(EnsembleContextManager, abc.ABC):
the abstract methods.
"""

def __init__(self):
super().__init__()
if not current_tf_version_is_older_than("2.10.0"):
raise RuntimeError(
"The DeviceAllocatorContextManager is not compatible with tensorflow 2.10.0 "
"or newer. Please fall back to a single GPU for now (see issue #75),"
"or downgrade to tensorflow 2.9.0."
)

# docstr-coverage: inherited
def __enter__(self) -> "DeviceAllocatorContextManager":
super().__enter__()
Expand Down

0 comments on commit 1029d9d

Please sign in to comment.