Skip to content

Commit

Permalink
More llm test fix
Browse files Browse the repository at this point in the history
  • Loading branch information
Giuseppe5 committed Nov 13, 2024
1 parent e723507 commit 604e3c4
Show file tree
Hide file tree
Showing 2 changed files with 10 additions and 1 deletion.
6 changes: 5 additions & 1 deletion src/brevitas/graph/equalize.py
Original file line number Diff line number Diff line change
Expand Up @@ -49,7 +49,11 @@
if torch_version >= packaging.version.parse('2.4'):
RMSNorm = nn.RMSNorm
else:
RMSNorm = object

class PlaceholderRMSNorm:
pass

RMSNorm = PlaceholderRMSNorm

__all__ = [
'GraphActivationEqualization',
Expand Down
5 changes: 5 additions & 0 deletions tests/brevitas_examples/test_llm.py
Original file line number Diff line number Diff line change
Expand Up @@ -5,6 +5,7 @@
from dataclasses import dataclass
import logging
import os
import platform
import shutil

import numpy as np
Expand Down Expand Up @@ -401,6 +402,10 @@ def test_small_models_quant_layer(caplog, layer_args):
if args.replace_rmsnorm:
if torch_version < version.parse('2.4'):
pytest.skip("Replacing RMSNorm requires torch 2.4+ or greater")
if hasattr(
args,
'graph_rotation') and args.graph_rotation == 'fx' and platform.system() == 'win32':
pytest.skip("Skipping dynamo + windows")
float_ppl, quant_ppl, model = validate_args_and_run_main(args)
assert_layer_types(model, exp_layer_types)

Expand Down

0 comments on commit 604e3c4

Please sign in to comment.