Skip to content

Commit

Permalink
Unpin version and make ground-truth values dependent on version
Browse files Browse the repository at this point in the history
  • Loading branch information
pablomlago committed Nov 21, 2024
1 parent e48db06 commit 5e4149c
Show file tree
Hide file tree
Showing 2 changed files with 10 additions and 5 deletions.
2 changes: 1 addition & 1 deletion requirements/requirements-llm.txt
Original file line number Diff line number Diff line change
@@ -1,3 +1,3 @@
# optimum-amd[brevitas] @ git+https://github.com/huggingface/optimum-amd.git@main
tqdm
transformers[sentencepiece]>=4.46.0
transformers
13 changes: 9 additions & 4 deletions tests/brevitas_examples/test_llm.py
Original file line number Diff line number Diff line change
Expand Up @@ -14,6 +14,7 @@
import pytest
import pytest_cases
import torch
import transformers

from brevitas import config
from brevitas import torch_version
Expand All @@ -40,6 +41,10 @@ def allexact(x, y):
return np.allclose(x, y, rtol=0.0, atol=0.0, equal_nan=False)


def transformers_version_ge(required_version: str):
return version.parse(required_version) >= version.parse(transformers.__version__)


# Check that all args in args are used
def validate_args(args):
a = vars(args)
Expand Down Expand Up @@ -207,14 +212,14 @@ def test_small_models_toggle_run_args_pt_ge_2_4(
"model": "hf-internal-testing/tiny-random-LlamaForCausalLM",
"act_equalization": "fx",
"bias_corr": True,
"float_ppl": 33312.0, # 33239.5,
"quant_ppl": 33056.0}, # 33283.75390625},
"float_ppl": 33312.0 if transformers_version_ge('4.46.0') else 33239.5,
"quant_ppl": 33056.0 if transformers_version_ge('4.46.0') else 33283.75390625},
{
"model": "hf-internal-testing/tiny-random-MistralForCausalLM",
"act_equalization": "layerwise",
"gptq": True,
"float_ppl": 31056.0, # 31274.05078125
"quant_ppl": 33056.0},]) # 33139.23046875},])
"float_ppl": 31056.0 if transformers_version_ge('4.46.0') else 31274.05078125,
"quant_ppl": 33056.0 if transformers_version_ge('4.46.0') else 33139.23046875},])
def acc_args_and_acc(default_run_args, request):
args = default_run_args
run_dict = request.param
Expand Down

0 comments on commit 5e4149c

Please sign in to comment.