@@ -18,6 +18,7 @@ python3 --version
18
18
# Check cuda
19
19
nvidia-smi
20
20
nvcc --version
21
+ echo " LD_LIBRARY_PATH before ${LD_LIBRARY_PATH} "
21
22
22
23
cd " src/github/keras"
23
24
pip install -U pip setuptools
43
44
44
45
if [ " $KERAS_BACKEND " == " jax" ]
45
46
then
47
+ export JAX_TRACEBACK_FILTERING=off
48
+
46
49
echo " JAX backend detected."
47
50
pip install -r requirements-jax-cuda.txt --progress-bar off --timeout 1000
48
51
pip uninstall -y keras keras-nightly
49
52
python3 -c ' import jax;print(jax.__version__);print(jax.default_backend())'
50
53
# Raise error if GPU is not detected.
51
54
python3 -c ' import jax;assert jax.default_backend().lower() == "gpu"'
52
55
56
+ echo " LD_LIBRARY_PATH after ${LD_LIBRARY_PATH} "
57
+
53
58
# TODO: keras/layers/merging/merging_test.py::MergingLayersTest::test_sparse_dot_2d Fatal Python error: Aborted
54
59
# TODO: keras/trainers/data_adapters/py_dataset_adapter_test.py::PyDatasetAdapterTest::test_basic_flow0 Fatal Python error: Aborted
55
60
# keras/backend/jax/distribution_lib_test.py is configured for CPU test for now.
56
- pytest keras --ignore keras/src/applications \
57
- --ignore keras/src/layers/merging/merging_test.py \
58
- --ignore keras/src/trainers/data_adapters/py_dataset_adapter_test.py \
59
- --ignore keras/src/backend/jax/distribution_lib_test.py \
60
- --ignore keras/src/distribution/distribution_lib_test.py \
61
- --cov=keras \
62
- --cov-config=pyproject.toml
61
+ # pytest keras --ignore keras/src/applications \
62
+ # --ignore keras/src/layers/merging/merging_test.py \
63
+ # --ignore keras/src/trainers/data_adapters/py_dataset_adapter_test.py \
64
+ # --ignore keras/src/backend/jax/distribution_lib_test.py \
65
+ # --ignore keras/src/distribution/distribution_lib_test.py \
66
+ # --cov=keras \
67
+ # --cov-config=pyproject.toml
68
+
69
+ pytest -s keras/src/utils/jax_layer_test.py
63
70
64
- pytest keras/src/distribution/distribution_lib_test.py --cov=keras --cov-config=pyproject.toml
71
+ # pytest keras/src/distribution/distribution_lib_test.py --cov=keras --cov-config=pyproject.toml
65
72
fi
66
73
67
74
if [ " $KERAS_BACKEND " == " torch" ]
0 commit comments