diff --git a/flax/nnx/tests/compat/test_module.py b/flax/nnx/tests/compat/module_test.py similarity index 95% rename from flax/nnx/tests/compat/test_module.py rename to flax/nnx/tests/compat/module_test.py index df76033510..d66883f5e7 100644 --- a/flax/nnx/tests/compat/test_module.py +++ b/flax/nnx/tests/compat/module_test.py @@ -14,6 +14,7 @@ import dataclasses +from absl.testing import absltest import jax import jax.numpy as jnp @@ -21,7 +22,7 @@ from flax.nnx import compat -class TestCompatModule: +class TestCompatModule(absltest.TestCase): def test_compact_basic(self): class Linear(compat.Module): dout: int @@ -131,4 +132,7 @@ def __call__(self, x): assert y.shape == (1, 5) assert hasattr(bar, 'foo') - assert isinstance(bar.foo, Foo) \ No newline at end of file + assert isinstance(bar.foo, Foo) + +if __name__ == '__main__': + absltest.main() diff --git a/flax/nnx/tests/compat/test_wrappers.py b/flax/nnx/tests/compat/wrappers_test.py similarity index 100% rename from flax/nnx/tests/compat/test_wrappers.py rename to flax/nnx/tests/compat/wrappers_test.py diff --git a/flax/nnx/tests/test_containers.py b/flax/nnx/tests/containers_test.py similarity index 92% rename from flax/nnx/tests/test_containers.py rename to flax/nnx/tests/containers_test.py index 4757d494ee..97785e7658 100644 --- a/flax/nnx/tests/test_containers.py +++ b/flax/nnx/tests/containers_test.py @@ -14,9 +14,10 @@ from flax import nnx +from absl.testing import absltest -class TestContainers: +class TestContainers(absltest.TestCase): def test_unbox(self): x = nnx.Param( 1, @@ -58,3 +59,7 @@ def __init__(self) -> None: assert module.x.value == 12 assert vars(module)['x'].raw_value == 12 + + +if __name__ == '__main__': + absltest.main() diff --git a/flax/nnx/tests/test_graph_utils.py b/flax/nnx/tests/graph_utils_test.py similarity index 100% rename from flax/nnx/tests/test_graph_utils.py rename to flax/nnx/tests/graph_utils_test.py diff --git a/flax/nnx/tests/test_helpers.py b/flax/nnx/tests/helpers_test.py similarity index 93% rename from flax/nnx/tests/test_helpers.py rename to flax/nnx/tests/helpers_test.py index 8a7cec4dbc..e97b5c0828 100644 --- a/flax/nnx/tests/test_helpers.py +++ b/flax/nnx/tests/helpers_test.py @@ -16,7 +16,8 @@ import jax.numpy as jnp import optax -from numpy.testing import assert_array_equal +from absl.testing import absltest +import numpy as np from flax import linen from flax import nnx @@ -90,7 +91,7 @@ def test_nnx_linen_sequential_equivalence(self): ).value out_nnx = model_nnx(x) out = model.apply(variables, x) - assert_array_equal(out, out_nnx) + np.testing.assert_array_equal(out, out_nnx) variables = model.init(key2, x) for layer_index in range(2): @@ -100,4 +101,9 @@ def test_nnx_linen_sequential_equivalence(self): ][f'layers_{layer_index}'][param] out_nnx = model_nnx(x) out = model.apply(variables, x) - assert_array_equal(out, out_nnx) + np.testing.assert_array_equal(out, out_nnx) + + +if __name__ == '__main__': + absltest.main() + diff --git a/flax/nnx/tests/test_ids.py b/flax/nnx/tests/ids_test.py similarity index 91% rename from flax/nnx/tests/test_ids.py rename to flax/nnx/tests/ids_test.py index d72490c836..49d71ef330 100644 --- a/flax/nnx/tests/test_ids.py +++ b/flax/nnx/tests/ids_test.py @@ -14,6 +14,7 @@ import copy +from absl.testing import absltest from flax.nnx.nnx import ids @@ -28,3 +29,7 @@ def test_hashable(self): id1dc = copy.deepcopy(id1) assert hash(id1) != hash(id1c) assert hash(id1) != hash(id1dc) + + +if __name__ == '__main__': + absltest.main() diff --git a/flax/nnx/tests/test_integration.py b/flax/nnx/tests/integration_test.py similarity index 98% rename from flax/nnx/tests/test_integration.py rename to flax/nnx/tests/integration_test.py index c473562b85..e33eebe7ed 100644 --- a/flax/nnx/tests/test_integration.py +++ b/flax/nnx/tests/integration_test.py @@ -14,6 +14,7 @@ import typing as tp +from absl.testing import absltest import jax import jax.numpy as jnp import numpy as np @@ -23,7 +24,7 @@ A = tp.TypeVar('A') -class TestIntegration: +class TestIntegration(absltest.TestCase): def test_shared_modules(self): class Block(nnx.Module): def __init__(self, linear: nnx.Linear, *, rngs): @@ -257,3 +258,7 @@ def __call__(self, x): intermediates, state = state.split(nnx.Intermediate, ...) assert 'y' in intermediates + + +if __name__ == '__main__': + absltest.main() diff --git a/flax/nnx/tests/test_metrics.py b/flax/nnx/tests/metrics_test.py similarity index 94% rename from flax/nnx/tests/test_metrics.py rename to flax/nnx/tests/metrics_test.py index 9a84cceb9a..f05a244d29 100644 --- a/flax/nnx/tests/test_metrics.py +++ b/flax/nnx/tests/metrics_test.py @@ -17,6 +17,7 @@ from flax import nnx +from absl.testing import absltest from absl.testing import parameterized @@ -63,4 +64,8 @@ def test_multimetric(self): metrics.reset() values = metrics.compute() self.assertTrue(jnp.isnan(values['accuracy'])) - self.assertTrue(jnp.isnan(values['loss'])) \ No newline at end of file + self.assertTrue(jnp.isnan(values['loss'])) + + +if __name__ == '__main__': + absltest.main() diff --git a/flax/nnx/tests/test_module.py b/flax/nnx/tests/module_test.py similarity index 100% rename from flax/nnx/tests/test_module.py rename to flax/nnx/tests/module_test.py diff --git a/flax/nnx/tests/nn/test_attention.py b/flax/nnx/tests/nn/attention_test.py similarity index 96% rename from flax/nnx/tests/nn/test_attention.py rename to flax/nnx/tests/nn/attention_test.py index 9c45264d9c..4f44eae607 100644 --- a/flax/nnx/tests/nn/test_attention.py +++ b/flax/nnx/tests/nn/attention_test.py @@ -19,13 +19,14 @@ from flax import nnx from flax.typing import Dtype, PrecisionLike -from numpy.testing import assert_array_equal +import numpy as np import typing as tp from absl.testing import parameterized +from absl.testing import absltest -class TestMultiHeadAttention: +class TestMultiHeadAttention(absltest.TestCase): def test_basic(self): module = nnx.MultiHeadAttention( num_heads=2, @@ -167,4 +168,8 @@ def test_nnx_attention_equivalence( out_nnx = model_nnx(x) out, cache = model.apply(variables, x, mutable=['cache']) - assert_array_equal(out, out_nnx) + np.testing.assert_array_equal(out, out_nnx) + + +if __name__ == '__main__': + absltest.main() diff --git a/flax/nnx/tests/nn/test_conv.py b/flax/nnx/tests/nn/conv_test.py similarity index 96% rename from flax/nnx/tests/nn/test_conv.py rename to flax/nnx/tests/nn/conv_test.py index 41a3a8044e..6f33d84fd3 100644 --- a/flax/nnx/tests/nn/test_conv.py +++ b/flax/nnx/tests/nn/conv_test.py @@ -16,10 +16,11 @@ import typing as tp import jax +from absl.testing import absltest from absl.testing import parameterized from jax import numpy as jnp from jax.lax import Precision -from numpy.testing import assert_array_equal +import numpy as np from flax import linen from flax import nnx @@ -102,7 +103,7 @@ def test_nnx_linen_conv_equivalence( out_nnx = model_nnx(x) out = model.apply(variables, x) - assert_array_equal(out, out_nnx) + np.testing.assert_array_equal(out, out_nnx) @parameterized.product( strides=[None, (2, 3)], @@ -166,4 +167,8 @@ def test_nnx_linen_convtranspose_equivalence( out_nnx = model_nnx(x) out = model.apply(variables, x) - assert_array_equal(out, out_nnx) \ No newline at end of file + np.testing.assert_array_equal(out, out_nnx) + + +if __name__ == '__main__': + absltest.main() \ No newline at end of file diff --git a/flax/nnx/tests/nn/test_embed.py b/flax/nnx/tests/nn/embed_test.py similarity index 87% rename from flax/nnx/tests/nn/test_embed.py rename to flax/nnx/tests/nn/embed_test.py index faababe008..eb0551f509 100644 --- a/flax/nnx/tests/nn/test_embed.py +++ b/flax/nnx/tests/nn/embed_test.py @@ -15,9 +15,10 @@ import typing as tp import jax +from absl.testing import absltest from absl.testing import parameterized from jax import numpy as jnp -from numpy.testing import assert_array_equal +import numpy as np from flax import linen from flax import nnx @@ -62,11 +63,15 @@ def test_nnx_linen_equivalence( out_nnx = model_nnx(x) out = model.apply(variables, x) - assert_array_equal(out, out_nnx) + np.testing.assert_array_equal(out, out_nnx) x = jax.numpy.ones((10,), dtype=input_dtype) * 10 out_nnx = model_nnx(x) out = model.apply(variables, x) assert isinstance(out, jax.Array) - assert_array_equal(out, out_nnx) - assert_array_equal(jax.numpy.isnan(out).all(), jax.numpy.array([True])) + np.testing.assert_array_equal(out, out_nnx) + np.testing.assert_array_equal(jax.numpy.isnan(out).all(), jax.numpy.array([True])) + + +if __name__ == '__main__': + absltest.main() diff --git a/flax/nnx/tests/nn/test_linear.py b/flax/nnx/tests/nn/linear_test.py similarity index 94% rename from flax/nnx/tests/nn/test_linear.py rename to flax/nnx/tests/nn/linear_test.py index aa55eb6427..46374d1bd2 100644 --- a/flax/nnx/tests/nn/test_linear.py +++ b/flax/nnx/tests/nn/linear_test.py @@ -16,9 +16,10 @@ import jax import jax.numpy as jnp +from absl.testing import absltest from absl.testing import parameterized from jax.lax import Precision -from numpy.testing import assert_array_equal +import numpy as np from flax import linen from flax import nnx @@ -91,7 +92,7 @@ def test_nnx_linear_equivalence( out_nnx = model_nnx(x) out = model.apply(variables, x) - assert_array_equal(out, out_nnx) + np.testing.assert_array_equal(out, out_nnx) @parameterized.product( einsum_str=['defab,bcef->adefc', 'd...ab,bc...->ad...c'], @@ -139,7 +140,7 @@ def test_nnx_einsum_equivalence( variables['params']['bias'] = model_nnx.bias.value out_nnx = model_nnx(x) out = model.apply(variables, x) - assert_array_equal(out, out_nnx) + np.testing.assert_array_equal(out, out_nnx) variables = model.init(key, x) model_nnx.kernel.value = variables['params']['kernel'] @@ -148,4 +149,8 @@ def test_nnx_einsum_equivalence( model_nnx.bias.value = variables['params']['bias'] out_nnx = model_nnx(x) out = model.apply(variables, x) - assert_array_equal(out, out_nnx) + np.testing.assert_array_equal(out, out_nnx) + + +if __name__ == '__main__': + absltest.main() diff --git a/flax/nnx/tests/nn/test_lora.py b/flax/nnx/tests/nn/lora_test.py similarity index 100% rename from flax/nnx/tests/nn/test_lora.py rename to flax/nnx/tests/nn/lora_test.py diff --git a/flax/nnx/tests/nn/test_normalization.py b/flax/nnx/tests/nn/normalization_test.py similarity index 92% rename from flax/nnx/tests/nn/test_normalization.py rename to flax/nnx/tests/nn/normalization_test.py index 3e30febcf6..add352b7c5 100644 --- a/flax/nnx/tests/nn/test_normalization.py +++ b/flax/nnx/tests/nn/normalization_test.py @@ -16,8 +16,9 @@ import jax import jax.numpy as jnp +from absl.testing import absltest from absl.testing import parameterized -from numpy.testing import assert_array_equal +import numpy as np from flax import linen from flax import nnx @@ -29,14 +30,14 @@ class TestLinenConsistency(parameterized.TestCase): dtype=[jnp.float32, jnp.float16], param_dtype=[jnp.float32, jnp.float16], use_fast_variance=[True, False], - mask=[None, jnp.array([True, False, True, False, True])], + mask=[None, np.array([True, False, True, False, True])], ) def test_nnx_linen_batchnorm_equivalence( self, dtype: tp.Optional[Dtype], param_dtype: Dtype, use_fast_variance: bool, - mask: tp.Optional[jax.Array], + mask: tp.Optional[np.ndarray], ): class NNXModel(nnx.Module): def __init__(self, dtype, param_dtype, use_fast_variance, rngs): @@ -99,20 +100,20 @@ def __call__(self, x, *, mask=None): nnx_model.linear.bias.value = variables['params']['linear']['bias'] nnx_out = nnx_model(x, mask=mask) - assert_array_equal(linen_out, nnx_out) + np.testing.assert_array_equal(linen_out, nnx_out) @parameterized.product( dtype=[jnp.float32, jnp.float16], param_dtype=[jnp.float32, jnp.float16], use_fast_variance=[True, False], - mask=[None, jnp.array([True, False, True, False, True])], + mask=[None, np.array([True, False, True, False, True])], ) def test_nnx_linen_layernorm_equivalence( self, dtype: tp.Optional[Dtype], param_dtype: Dtype, use_fast_variance: bool, - mask: tp.Optional[jax.Array], + mask: tp.Optional[np.ndarray], ): class NNXModel(nnx.Module): def __init__(self, dtype, param_dtype, use_fast_variance, rngs): @@ -171,20 +172,20 @@ def __call__(self, x, *, mask=None): nnx_model.linear.bias.value = variables['params']['linear']['bias'] nnx_out = nnx_model(x, mask=mask) - assert_array_equal(linen_out, nnx_out) + np.testing.assert_array_equal(linen_out, nnx_out) @parameterized.product( dtype=[jnp.float32, jnp.float16], param_dtype=[jnp.float32, jnp.float16], use_fast_variance=[True, False], - mask=[None, jnp.array([True, False, True, False, True])], + mask=[None, np.array([True, False, True, False, True])], ) def test_nnx_linen_rmsnorm_equivalence( self, dtype: tp.Optional[Dtype], param_dtype: Dtype, use_fast_variance: bool, - mask: tp.Optional[jax.Array], + mask: tp.Optional[np.ndarray], ): class NNXModel(nnx.Module): def __init__(self, dtype, param_dtype, use_fast_variance, rngs): @@ -243,4 +244,8 @@ def __call__(self, x, *, mask=None): nnx_model.linear.bias.value = variables['params']['linear']['bias'] nnx_out = nnx_model(x, mask=mask) - assert_array_equal(linen_out, nnx_out) + np.testing.assert_array_equal(linen_out, nnx_out) + + +if __name__ == '__main__': + absltest.main() diff --git a/flax/nnx/tests/nn/test_stochastic.py b/flax/nnx/tests/nn/stochastic_test.py similarity index 100% rename from flax/nnx/tests/nn/test_stochastic.py rename to flax/nnx/tests/nn/stochastic_test.py diff --git a/flax/nnx/tests/test_optimizer.py b/flax/nnx/tests/optimizer_test.py similarity index 96% rename from flax/nnx/tests/test_optimizer.py rename to flax/nnx/tests/optimizer_test.py index a7e0310f18..b612ca3b34 100644 --- a/flax/nnx/tests/test_optimizer.py +++ b/flax/nnx/tests/optimizer_test.py @@ -19,6 +19,7 @@ from flax import nnx +from absl.testing import absltest from absl.testing import parameterized @@ -117,4 +118,8 @@ def update(self, *, grads, **updates): # type: ignore[signature-mismatch] state.update(grads=grads, values=loss_fn(state.model)) initial_loss = state.metrics.compute() state.update(grads=grads, values=loss_fn(state.model)) - self.assertTrue(state.metrics.compute() < initial_loss) \ No newline at end of file + self.assertTrue(state.metrics.compute() < initial_loss) + + +if __name__ == '__main__': + absltest.main() diff --git a/flax/nnx/tests/test_partitioning.py b/flax/nnx/tests/partitioning_test.py similarity index 100% rename from flax/nnx/tests/test_partitioning.py rename to flax/nnx/tests/partitioning_test.py diff --git a/flax/nnx/tests/test_rngs.py b/flax/nnx/tests/rngs_test.py similarity index 100% rename from flax/nnx/tests/test_rngs.py rename to flax/nnx/tests/rngs_test.py diff --git a/flax/nnx/tests/test_spmd.py b/flax/nnx/tests/spmd_test.py similarity index 93% rename from flax/nnx/tests/test_spmd.py rename to flax/nnx/tests/spmd_test.py index 0353bfc535..15808e0800 100644 --- a/flax/nnx/tests/test_spmd.py +++ b/flax/nnx/tests/spmd_test.py @@ -12,19 +12,20 @@ # See the License for the specific language governing permissions and # limitations under the License. +from absl.testing import absltest import jax import jax.numpy as jnp import optax -from jax._src import test_util as jtu from jax.experimental import mesh_utils from jax.sharding import Mesh, PartitionSpec from flax import nnx -class TestSPMD: - @jtu.skip_on_devices('cpu', 'gpu') +class TestSPMD(absltest.TestCase): def test_init(self): + if jax.device_count() < 4: + self.skipTest('At least 4 devices required') class Foo(nnx.Module): def __init__(self): self.w = nnx.Param( @@ -98,3 +99,8 @@ def __call__(self, x): assert state_spec.params['w'].value == PartitionSpec('row', 'col') assert state_spec.opt_state[0].mu['w'].value == PartitionSpec('row', 'col') assert state_spec.opt_state[0].nu['w'].value == PartitionSpec('row', 'col') + + +if __name__ == '__main__': + absltest.main() + diff --git a/flax/nnx/tests/test_state.py b/flax/nnx/tests/state_test.py similarity index 100% rename from flax/nnx/tests/test_state.py rename to flax/nnx/tests/state_test.py diff --git a/flax/nnx/tests/test_transforms.py b/flax/nnx/tests/transforms_test.py similarity index 100% rename from flax/nnx/tests/test_transforms.py rename to flax/nnx/tests/transforms_test.py diff --git a/flax/nnx/tests/test_variable.py b/flax/nnx/tests/variable_test.py similarity index 93% rename from flax/nnx/tests/test_variable.py rename to flax/nnx/tests/variable_test.py index de5c5c52c0..5b3e899490 100644 --- a/flax/nnx/tests/test_variable.py +++ b/flax/nnx/tests/variable_test.py @@ -17,12 +17,13 @@ import jax import jax.numpy as jnp +from absl.testing import absltest from flax import nnx A = tp.TypeVar('A') -class TestVariableState: +class TestVariableState(absltest.TestCase): def test_pytree(self): r1 = nnx.VariableState(nnx.Param, 1) assert r1.value == 1 @@ -62,3 +63,7 @@ def __call__(self, x: jax.Array): x = jax.numpy.ones((3,)) y = linear(x) assert y.shape == (4,) + + +if __name__ == '__main__': + absltest.main()