Skip to content

Commit

Permalink
fix grad test
Browse files Browse the repository at this point in the history
  • Loading branch information
calad0i committed Nov 28, 2023
1 parent 7adf7e3 commit ca59326
Show file tree
Hide file tree
Showing 2 changed files with 3 additions and 3 deletions.
4 changes: 2 additions & 2 deletions test/helpers.py
Original file line number Diff line number Diff line change
Expand Up @@ -126,14 +126,14 @@ def _run_gradient_test(model, data):

def run_model_test(model: keras.Model, cover_factor: float | None, data, io_type: str, backend: str, dir: str, aggressive: bool, no_exact_match: bool = False, skip_sl_test=False, test_gard=False):
data_len = data.shape[0] if isinstance(data, np.ndarray) else data[0].shape[0]
if test_gard:
_run_gradient_test(model, data)
if cover_factor is not None:
trace_minmax(model, data, cover_factor=cover_factor, bsz=data_len)
proxy = to_proxy_model(model, aggressive=aggressive)
try:
if not skip_sl_test:
_run_model_sl_test(model, proxy, data, dir)
if test_gard:
_run_gradient_test(model, data)
_run_model_proxy_match_test(model, proxy, data, cover_factor)
_run_synth_match_test(proxy, data, io_type, backend, dir)
except AssertionError as e:
Expand Down
2 changes: 1 addition & 1 deletion test/test_syn_hlayers.py
Original file line number Diff line number Diff line change
Expand Up @@ -84,7 +84,7 @@ def test_syn_hlayers(layer, N: int, rnd_strategy: str, io_type: str, cover_facto
model = create_model(layer=layer, rnd_strategy=rnd_strategy, io_type=io_type)
data = get_data(N, 1, 1, seed)

run_model_test(model, cover_factor, data, io_type, backend, dir, aggressive, test_gard=True)
run_model_test(model, cover_factor, data, io_type, backend, dir, aggressive, test_gard=N > 100)


if __name__ == '__main__':
Expand Down

0 comments on commit ca59326

Please sign in to comment.