Fix BT CI (#1872)
* fix bt test failures due to default sdpa attention

* exclude macos13+py3.8

* update tr

* check transformers version
IlyasMoutawwakil authored May 27, 2024
1 parent ff0a0b3 commit e81bd73
Showing 4 changed files with 39 additions and 28 deletions.
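
Background for the changes below: starting with transformers v4.36.0, models that have a scaled dot-product attention (SDPA) implementation are loaded with it by default, which is what the commit message cites as the cause of the BetterTransformer test failures. The fix applied throughout this commit is to request the eager attention implementation explicitly before converting a model. A minimal sketch of the pattern (not taken from the diff; the model name is just an example):

from transformers import AutoModel
from optimum.bettertransformer import BetterTransformer

# Request eager attention before converting; on transformers >= 4.36 the default
# may be SDPA, which the commit message names as the cause of the test failures.
model = AutoModel.from_pretrained("bert-base-uncased", attn_implementation="eager")
bt_model = BetterTransformer.transform(model, keep_original_model=True)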
54 changes: 30 additions & 24 deletions .github/workflows/test_bettertransformer.yml
@@ -2,9 +2,9 @@ name: BetterTransformer / Python - Test
 
 on:
   push:
-    branches: [ main ]
+    branches: [main]
   pull_request:
-    branches: [ main ]
+    branches: [main]
 
 concurrency:
   group: ${{ github.workflow }}-${{ github.head_ref || github.run_id }}
@@ -17,29 +17,35 @@ jobs:
       matrix:
         python-version: [3.8, 3.9]
         os: [ubuntu-20.04, macos-13]
+        exclude: [{ python-version: 3.8, os: macos-13 }]
 
     runs-on: ${{ matrix.os }}
     steps:
-      - uses: actions/checkout@v2
-      - name: Setup Python ${{ matrix.python-version }}
-        uses: actions/setup-python@v2
-        with:
-          python-version: ${{ matrix.python-version }}
-      - name: Install dependencies
-        run: |
-          pip install .[tests]
-          pip install --no-cache-dir --upgrade torch torchvision torchaudio
-          pip install accelerate
-      - name: Test on pytorch stable
-        working-directory: tests
-        run: |
-          pytest bettertransformer/test_*.py -s -vvvvv
-      - name: Install dependencies 2
-        run: |
-          pip uninstall -y torch torchvision torchaudio
-          pip install --no-cache-dir --pre torch torchvision torchaudio --index-url https://download.pytorch.org/whl/nightly/cpu
-      - name: Test on pytorch nightly
-        working-directory: tests
-        run: |
-          pytest bettertransformer/test_*.py -s -vvvvv
+      - name: Checkout code
+        uses: actions/checkout@v4
+
+      - name: Setup Python ${{ matrix.python-version }}
+        uses: actions/setup-python@v3
+        with:
+          python-version: ${{ matrix.python-version }}
+
+      - name: Install dependencies
+        run: |
+          pip install .[tests]
+          pip install --no-cache-dir --upgrade torch torchvision torchaudio --index-url https://download.pytorch.org/whl/cpu
+          pip install accelerate
+      - name: Test with stable pytorch
+        working-directory: tests
+        run: |
+          pytest bettertransformer -s -vvvvv
+      - name: Install dependencies 2
+        run: |
+          pip uninstall -y torch torchvision torchaudio
+          pip install --no-cache-dir --pre torch torchvision torchaudio --index-url https://download.pytorch.org/whl/nightly/cpu
+      - name: Test with nightly pytorch
+        working-directory: tests
+        run: |
+          pytest bettertransformer -s -vvvvv
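
As a quick illustration of the new exclude entry (not part of the workflow itself), the matrix expands to the Cartesian product of Python versions and runners minus the excluded pair, leaving three CI jobs:

from itertools import product

python_versions = ["3.8", "3.9"]
runners = ["ubuntu-20.04", "macos-13"]
excluded = {("3.8", "macos-13")}  # the combination this commit stops testing

jobs = [combo for combo in product(python_versions, runners) if combo not in excluded]
print(jobs)  # [('3.8', 'ubuntu-20.04'), ('3.9', 'ubuntu-20.04'), ('3.9', 'macos-13')]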
9 changes: 7 additions & 2 deletions optimum/pipelines/pipelines_base.py
@@ -45,7 +45,7 @@
 from transformers.pipelines import infer_framework_load_model
 
 from ..bettertransformer import BetterTransformer
-from ..utils import is_onnxruntime_available
+from ..utils import check_if_transformers_greater, is_onnxruntime_available
 from ..utils.file_utils import find_files_matching_pattern

@@ -179,7 +179,12 @@ def load_bettertransformer(
     **kwargs,
 ):
     if model_kwargs is None:
-        model_kwargs = {}
+        # the argument was first introduced in 4.36.0 but most models didn't have an sdpa implementation then
+        # see https://github.com/huggingface/transformers/blob/v4.36.0/src/transformers/modeling_utils.py#L1258
+        if check_if_transformers_greater("4.36.0"):
+            model_kwargs = {"attn_implementation": "eager"}
+        else:
+            model_kwargs = {}
 
     if model is None:
         model_id = SUPPORTED_TASKS[targeted_task]["default"]
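
load_bettertransformer is the code path optimum's pipeline factory takes when a pipeline is requested with the BetterTransformer accelerator and the caller did not supply model_kwargs; with this change it now injects eager attention on transformers versions that understand the kwarg. A hedged usage sketch follows; the task, model name, and accelerator keyword are illustrative assumptions, not something this diff shows:

from optimum.pipelines import pipeline

# On transformers >= 4.36.0 the model is loaded with attn_implementation="eager"
# under the hood before BetterTransformer converts it, so the conversion no longer
# collides with the default SDPA attention.
pipe = pipeline("fill-mask", model="bert-base-uncased", accelerator="bettertransformer")
print(pipe("The capital of France is [MASK]."))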
2 changes: 1 addition & 1 deletion tests/bettertransformer/test_encoder.py
@@ -114,7 +114,7 @@ def test_inference_speed(self):
         """
         model_name = "bert-base-uncased"
 
-        hf_model = AutoModel.from_pretrained(model_name).eval()
+        hf_model = AutoModel.from_pretrained(model_name, attn_implementation="eager").eval()
         bt_model = BetterTransformer.transform(hf_model, keep_original_model=True)
 
         BATCH_SIZE = 8
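
For context, test_inference_speed compares the latency of the stock model against its BetterTransformer conversion; a rough, assumed sketch of that comparison (the eager loading and batch size come from the test above, everything else is illustrative):

import time
import torch
from transformers import AutoModel
from optimum.bettertransformer import BetterTransformer

hf_model = AutoModel.from_pretrained("bert-base-uncased", attn_implementation="eager").eval()
bt_model = BetterTransformer.transform(hf_model, keep_original_model=True)

BATCH_SIZE = 8
inputs = {
    "input_ids": torch.randint(1, 30000, (BATCH_SIZE, 64)),
    "attention_mask": torch.ones(BATCH_SIZE, 64, dtype=torch.long),
}

def bench(model, n_iter=10):
    # wall-clock time of n_iter forward passes on the same batch
    with torch.no_grad():
        start = time.perf_counter()
        for _ in range(n_iter):
            model(**inputs)
    return time.perf_counter() - start

print(f"eager: {bench(hf_model):.3f}s  bettertransformer: {bench(bt_model):.3f}s")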
2 changes: 1 addition & 1 deletion tests/bettertransformer/testing_utils.py
@@ -235,7 +235,7 @@ def _test_logits(self, model_id: str, model_type: str, **preprocessor_kwargs):
         inputs = self.prepare_inputs_for_class(model_id=model_id, model_type=model_type, **preprocessor_kwargs)
 
         torch.manual_seed(0)
-        hf_random_model = AutoModel.from_pretrained(model_id).eval()
+        hf_random_model = AutoModel.from_pretrained(model_id, attn_implementation="eager").eval()
         random_config = hf_random_model.config
 
         hf_random_model = hf_random_model.eval()
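
For orientation, a rough sketch of what a logits check like _test_logits does end to end: load the model with eager attention, convert it, and compare outputs on the same inputs. Everything below except the eager loading shown in the diff is an assumption about the test's shape; the input ids and tolerances are made up:

import torch
from transformers import AutoModel
from optimum.bettertransformer import BetterTransformer

torch.manual_seed(0)
hf_model = AutoModel.from_pretrained("bert-base-uncased", attn_implementation="eager").eval()
bt_model = BetterTransformer.transform(hf_model, keep_original_model=True)

# a single unpadded sequence, so eager and BetterTransformer outputs should agree closely
inputs = {
    "input_ids": torch.tensor([[101, 7592, 2088, 102]]),
    "attention_mask": torch.ones(1, 4, dtype=torch.long),
}
with torch.no_grad():
    hf_out = hf_model(**inputs).last_hidden_state
    bt_out = bt_model(**inputs).last_hidden_state
torch.testing.assert_close(hf_out, bt_out, atol=1e-4, rtol=1e-4)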
