Skip to content

Commit

Permalink
[JAX FE][GHA] Add JAX/Flax Model Hub tests into GHA (openvinotoolkit#…
Browse files Browse the repository at this point in the history
…26542)

**Details:** Add JAX Model Hub tests into GHA

**Ticket:** 142882

---------

Signed-off-by: Kazantsev, Roman <[email protected]>
  • Loading branch information
rkazants authored Sep 12, 2024
1 parent f2ea004 commit deb702e
Show file tree
Hide file tree
Showing 12 changed files with 333 additions and 16 deletions.
9 changes: 7 additions & 2 deletions .github/labeler.yml
Original file line number Diff line number Diff line change
Expand Up @@ -151,7 +151,8 @@
- 'tests/layer_tests/tensorflow2_keras_tests/**/*'
- 'tests/layer_tests/jax_tests/**/*'
- any: ['tests/model_hub_tests/**',
'!tests/model_hub_tests/pytorch/**/*']
'!tests/model_hub_tests/pytorch/**/*',
'!tests/model_hub_tests/jax/**/*']

'category: TFL FE':
- 'src/frontends/tensorflow_lite/**/*'
Expand All @@ -165,12 +166,16 @@
- 'tests/layer_tests/py_frontend_tests/test_torch_decoder.py'
- 'tests/layer_tests/py_frontend_tests/test_torch_frontend.py'
- any: ['tests/model_hub_tests/**',
'!tests/model_hub_tests/tensorflow/**/*']
'!tests/model_hub_tests/tensorflow/**/*',
'!tests/model_hub_tests/jax/**/*']

'category: JAX FE':
- 'src/frontends/jax/**/*'
- 'src/bindings/python/src/openvino/frontend/jax/**/*'
- 'tests/layer_tests/jax_tests/**/*'
- any: ['tests/model_hub_tests/**',
'!tests/model_hub_tests/tensorflow/**/*',
'!tests/model_hub_tests/pytorch/**/*']

'category: tools':
- any: ['tools/**',
Expand Down
116 changes: 116 additions & 0 deletions .github/workflows/job_jax_models_tests.yml
Original file line number Diff line number Diff line change
@@ -0,0 +1,116 @@
name: JAX/Flax Models tests

on:
workflow_call:
inputs:
runner:
description: 'Machine on which the tests would run'
type: string
required: true
container:
description: 'JSON to be converted to the value of the "container" configuration for the job'
type: string
required: false
default: '{"image": null}'
model_scope:
description: 'Scope of models for testing.'
type: string
required: true

permissions: read-all

jobs:
JAX_Models_Tests:
name: JAX/Flax Models tests
timeout-minutes: ${{ inputs.model_scope == 'precommit' && 35 || 35 }}
runs-on: ${{ inputs.runner }}
container: ${{ fromJSON(inputs.container) }}
defaults:
run:
shell: bash
env:
DEBIAN_FRONTEND: noninteractive # to prevent apt-get from waiting user input
OPENVINO_REPO: ${{ github.workspace }}/openvino
INSTALL_DIR: ${{ github.workspace }}/install
INSTALL_TEST_DIR: ${{ github.workspace }}/install/tests
MODEL_HUB_TESTS_INSTALL_DIR: ${{ github.workspace }}/install/tests/model_hub_tests
steps:
- name: Download OpenVINO package
uses: actions/download-artifact@fa0a91b85d4f404e444e00e005971372dc801d16 # v4.1.8
with:
name: openvino_package
path: ${{ env.INSTALL_DIR }}

- name: Download OpenVINO tokenizers extension
uses: actions/download-artifact@fa0a91b85d4f404e444e00e005971372dc801d16 # v4.1.8
with:
name: openvino_tokenizers_wheel
path: ${{ env.INSTALL_DIR }}

- name: Download OpenVINO tests package
uses: actions/download-artifact@fa0a91b85d4f404e444e00e005971372dc801d16 # v4.1.8
with:
name: openvino_tests
path: ${{ env.INSTALL_TEST_DIR }}

# Needed as ${{ github.workspace }} is not working correctly when using Docker
- name: Setup Variables
run: |
echo "OPENVINO_REPO=$GITHUB_WORKSPACE/openvino" >> "$GITHUB_ENV"
echo "INSTALL_DIR=$GITHUB_WORKSPACE/install" >> "$GITHUB_ENV"
echo "INSTALL_TEST_DIR=$GITHUB_WORKSPACE/install/tests" >> "$GITHUB_ENV"
echo "MODEL_HUB_TESTS_INSTALL_DIR=$GITHUB_WORKSPACE/install/tests/model_hub_tests" >> "$GITHUB_ENV"
- name: Extract OpenVINO packages
run: |
pushd ${INSTALL_DIR}
tar -xzf openvino_package.tar.gz -C ${INSTALL_DIR}
popd
pushd ${INSTALL_TEST_DIR}
tar -xzf openvino_tests.tar.gz -C ${INSTALL_DIR}
popd
- name: Fetch setup_python action
uses: actions/checkout@692973e3d937129bcbf40652eb9f2f61becf3332 # v4.1.7
with:
sparse-checkout: |
.github/actions/setup_python/action.yml
sparse-checkout-cone-mode: false
path: 'openvino'

- name: Setup Python 3.11
uses: ./openvino/.github/actions/setup_python
with:
version: '3.11'
should-setup-pip-paths: 'false'
self-hosted-runner: ${{ contains(inputs.runner, 'aks') }}

- name: Install OpenVINO Python wheels
run: |
# To enable pytest parallel features
python3 -m pip install pytest-xdist[psutil]
python3 -m pip install ${INSTALL_DIR}/tools/openvino-*
python3 -m pip install ${INSTALL_DIR}/openvino_tokenizers-*
- name: Install JAX tests requirements for precommit
run: |
python3 -m pip install -r ${MODEL_HUB_TESTS_INSTALL_DIR}/jax/requirements.txt
- name: JAX/Flax Models Tests from Hugging Face
if: ${{ inputs.model_scope == 'precommit' || inputs.model_scope == 'nightly' }}
run: |
export PYTHONPATH=${MODEL_HUB_TESTS_INSTALL_DIR}:$PYTHONPATH
python3 -m pytest ${MODEL_HUB_TESTS_INSTALL_DIR}/jax/ -m ${TYPE} --html=${INSTALL_TEST_DIR}/TEST-jax_model_${{ inputs.model_scope }}_tests.html --self-contained-html -v
env:
TYPE: ${{ inputs.model_scope == 'precommit' && 'precommit' || 'nightly' }}
TEST_DEVICE: CPU

- name: Upload Test Results
uses: actions/upload-artifact@50769540e7f4bd5e21e526ee35c689e35e0d6874 # v4.4.0
if: ${{ !cancelled() }}
with:
name: test-results-jax-models-${{ inputs.model_scope }}
path: |
${{ env.INSTALL_TEST_DIR }}/TEST*.html
if-no-files-found: 'error'
12 changes: 11 additions & 1 deletion .github/workflows/ubuntu_22.yml
Original file line number Diff line number Diff line change
Expand Up @@ -379,6 +379,16 @@ jobs:
runner: 'ubuntu-22.04-16-cores'
model_scope: 'nightly_scope2'

JAX_Models_Tests_Precommit:
name: JAX/Flax Models tests
if: fromJSON(needs.smart_ci.outputs.affected_components).JAX_FE.test
needs: [ Docker, Build, Smart_CI, Openvino_tokenizers ]
uses: ./.github/workflows/job_jax_models_tests.yml
with:
runner: 'aks-linux-8-cores-16gb'
model_scope: 'precommit'
container: '{"image": "${{ fromJSON(needs.docker.outputs.images).ov_build.ubuntu_22_04_x64 }}", "volumes": ["/mount:/mount"]}'

NVIDIA_Plugin:
name: NVIDIA plugin
needs: [ Docker, Build, Smart_CI ]
Expand Down Expand Up @@ -508,7 +518,7 @@ jobs:
Overall_Status:
name: ci/gha_overall_status
needs: [Smart_CI, Build, Debian_Packages, Samples, Conformance, ONNX_Runtime, CXX_Unit_Tests, Python_Unit_Tests, TensorFlow_Layer_Tests,
CPU_Functional_Tests, TensorFlow_Models_Tests_Precommit, PyTorch_Models_Tests, NVIDIA_Plugin, Openvino_tokenizers, iGPU]
CPU_Functional_Tests, TensorFlow_Models_Tests_Precommit, PyTorch_Models_Tests, JAX_Models_Tests_Precommit, NVIDIA_Plugin, Openvino_tokenizers, iGPU]
if: ${{ always() }}
runs-on: ubuntu-latest
steps:
Expand Down
30 changes: 17 additions & 13 deletions src/bindings/python/src/openvino/frontend/jax/utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -142,21 +142,25 @@ def ivalue_to_constant(ivalue, shared_memory=True):
if ov_type.is_static():
return op.Constant(ov_type, Shape([]), [ivalue]).outputs()
if isinstance(ivalue, (list, tuple)):
# TODO 150596: remove this workaround
if len(ivalue) == 0:
return op.Constant(OVType.i64, Shape([0]), []).outputs()
assert len(ivalue) > 0, "Can't deduce type for empty list"
if isinstance(ivalue[0], (list, tuple)):
second_len = len(ivalue[0])
flattened_ivalue = []
for value in ivalue:
assert isinstance(value, (list, tuple)), "Can't deduce type for a list with both list and basic types."
assert len(value) == second_len or len(value) == 0, "Can't deduce type for nested list with different lengths."
flattened_ivalue.extend([filter_element(item) for item in value])
flattened_ivalue = [item for sublist in ivalue for item in sublist]
ov_type = _get_ov_type_from_value(flattened_ivalue[0])
assert ov_type.is_static(), f"Can't deduce type {flattened_ivalue[0].__class__} for list"
return op.Constant(ov_type, Shape([len(ivalue), second_len]), flattened_ivalue).outputs()
ivalue = [filter_element(item) for item in ivalue]
ov_type = _get_ov_type_from_value(ivalue[0])
try:
if isinstance(ivalue[0], (list, tuple)):
second_len = len(ivalue[0])
flattened_ivalue = []
for value in ivalue:
assert isinstance(value, (list, tuple)), "Can't deduce type for a list with both list and basic types."
assert len(value) == second_len or len(
value) == 0, "Can't deduce type for nested list with different lengths."
flattened_ivalue.extend([filter_element(item) for item in value])
flattened_ivalue = [item for sublist in ivalue for item in sublist]
ov_type = _get_ov_type_from_value(flattened_ivalue[0])
assert ov_type.is_static(), f"Can't deduce type {flattened_ivalue[0].__class__} for list"
return op.Constant(ov_type, Shape([len(ivalue), second_len]), flattened_ivalue).outputs()
ivalue = [filter_element(item) for item in ivalue]
ov_type = _get_ov_type_from_value(ivalue[0])
assert ov_type.is_static(), f"Can't deduce type {ivalue[0].__class__} for list"
except:
# TODO 150596: remove this workaround
Expand Down
2 changes: 2 additions & 0 deletions src/frontends/jax/src/op_table.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -6,6 +6,7 @@

#include "openvino/op/add.hpp"
#include "openvino/op/divide.hpp"
#include "openvino/op/erf.hpp"
#include "openvino/op/exp.hpp"
#include "openvino/op/maximum.hpp"
#include "openvino/op/multiply.hpp"
Expand Down Expand Up @@ -59,6 +60,7 @@ const std::map<std::string, CreatorFunction> get_supported_ops_jaxpr() {
{"device_put", op::skip_node},
{"div", op::translate_1to1_match_2_inputs<v1::Divide>},
{"dot_general", op::translate_dot_general},
{"erf", op::translate_1to1_match_1_input<v0::Erf>},
{"exp", op::translate_1to1_match_1_input<v0::Exp>},
{"integer_pow", op::translate_integer_pow},
{"max", op::translate_1to1_match_2_inputs<v1::Maximum>},
Expand Down
1 change: 1 addition & 0 deletions tests/constraints.txt
Original file line number Diff line number Diff line change
Expand Up @@ -28,6 +28,7 @@ kornia==0.7.0
networkx<=3.3
timm==1.0.8
transformers~=4.44
flax<=0.9.0

--extra-index-url https://download.pytorch.org/whl/cpu
torch~=2.4.1
39 changes: 39 additions & 0 deletions tests/layer_tests/jax_tests/test_erf.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,39 @@
# Copyright (C) 2018-2024 Intel Corporation
# SPDX-License-Identifier: Apache-2.0

import jax
import numpy as np
import pytest
from jax import numpy as jnp

from jax_layer_test_class import JaxLayerTest

rng = np.random.default_rng(109734)


class TestErf(JaxLayerTest):
def _prepare_input(self):
# erf are mostly changing in a range [-4, 4]
x = rng.uniform(-4.0, 4.0, self.input_shape).astype(self.input_type)

x = jnp.array(x)
return [x]

def create_model(self, input_shape, input_type):
self.input_shape = input_shape
self.input_type = input_type

def jax_erf(x):
return jax.lax.erf(x)

return jax_erf, None, 'erf'

@pytest.mark.parametrize("input_shape", [[2], [3, 4]])
@pytest.mark.parametrize("input_type", [np.float16, np.float32, np.float64])
@pytest.mark.nightly
@pytest.mark.precommit
@pytest.mark.precommit_jax_fe
def test_erf(self, ie_device, precision, ir_version, input_shape, input_type):
self._test(*self.create_model(input_shape, input_type),
ie_device, precision,
ir_version)
12 changes: 12 additions & 0 deletions tests/model_hub_tests/jax/conftest.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,12 @@
# Copyright (C) 2018-2024 Intel Corporation
# SPDX-License-Identifier: Apache-2.0

import inspect

from models_hub_common.utils import get_params


def pytest_generate_tests(metafunc):
test_gen_attrs_names = list(inspect.signature(get_params).parameters)
params = get_params()
metafunc.parametrize(test_gen_attrs_names, params, scope="function")
7 changes: 7 additions & 0 deletions tests/model_hub_tests/jax/hf_transformers_models
Original file line number Diff line number Diff line change
@@ -0,0 +1,7 @@
# type,name,mark,reason
# top transformers from https://huggingface.co/docs/transformers/index
vit-base,google/vit-base-patch16-224-in21k
albert-base-v2,albert/albert-base-v2,skip,No conversion rule found for iota pjit select_n
bart-base,facebook/bart-base,skip,No conversion rule found for and eq ge gt iota ne pjit scatter select_n
clip-vit-base-patch32,openai/clip-vit-base-patch32,skip,No conversion rule found for and argmax gather ge gt iota logistic lt ne pjit select_n
Mistral-tiny,ksmcg/Mistral-tiny,skip,No conversion rule found for and gather ge gt iota lt ne neg pjit select_n
40 changes: 40 additions & 0 deletions tests/model_hub_tests/jax/jax_utils.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,40 @@
# Copyright (C) 2018-2024 Intel Corporation
# SPDX-License-Identifier: Apache-2.0

import flax
import jax
import numpy as np
from models_hub_common.test_convert_model import TestConvertModel
from openvino import convert_model


def flattenize_pytree(outputs):
leaves, _ = jax.tree_util.tree_flatten(outputs)
return [np.array(i) if isinstance(i, jax.Array) else i for i in leaves]


class TestJaxConvertModel(TestConvertModel):
def get_inputs_info(self, _):
return None

def prepare_inputs(self, _):
inputs = getattr(self, 'inputs', self.example)
return inputs

def convert_model(self, model_obj):
if isinstance(model_obj, flax.linen.Module):
ov_model = convert_model(model_obj, example_input=self.example, verbose=True)
else:
# create JAXpr object
jaxpr = jax.make_jaxpr(model_obj)(**self.example)
ov_model = convert_model(jaxpr, verbose=True)
return ov_model

def infer_fw_model(self, model_obj, inputs):
if isinstance(inputs, dict):
fw_outputs = model_obj(**inputs)
elif isinstance(inputs, list):
fw_outputs = model_obj(*inputs)
else:
fw_outputs = model_obj(inputs)
return flattenize_pytree(fw_outputs)
10 changes: 10 additions & 0 deletions tests/model_hub_tests/jax/requirements.txt
Original file line number Diff line number Diff line change
@@ -0,0 +1,10 @@
-c ../../constraints.txt
numpy
pytest
pytest-html
transformers
requests
jax
jaxlib
flax
pillow
Loading

0 comments on commit deb702e

Please sign in to comment.