Skip to content

Commit

Permalink
yet more tests
Browse files Browse the repository at this point in the history
  • Loading branch information
calad0i committed Nov 28, 2023
1 parent 84eb144 commit 008b489
Showing 1 changed file with 25 additions and 12 deletions.
37 changes: 25 additions & 12 deletions test/test_syn_players.py
Original file line number Diff line number Diff line change
Expand Up @@ -62,21 +62,27 @@ def get_data(N: int, sigma: float, max_scale: float, seed):
return (a1 * a2).astype(np.float32)


def parallel_avg_pool_cond(a, b):
close: np.ndarray = np.abs(a - b) < 1e-2

assert np.all(close), f"Keras-Proxy mismatch for approx avg pool: {np.sum(np.any(~close, axis=tuple(range(1,close.ndim))))} out of {a.shape[0]} samples are very different. Sample: {a[~close].ravel()[:5]} vs {b[~close].ravel()[:5]}"


@pytest.mark.parametrize('layer',
[
"PConcatenate()",
# "PMaxPool1D(2, padding='same')",
# "PMaxPool2D((2,2), padding='same')",
# "PMaxPool1D(2, padding='valid')",
# "PMaxPool2D((2,2), padding='valid')",
"PMaxPool1D(2, padding='same')",
"PMaxPool2D((2,2), padding='same')",
"PMaxPool1D(2, padding='valid')",
"PMaxPool2D((2,2), padding='valid')",
"Signature(1,6,3)"
# "PAvgPool1D(2, padding='same')",
# "PAvgPool2D((1,2), padding='same')",
# "PAvgPool2D((2,2), padding='same')",
# "PAvgPool1D(2, padding='valid')",
# "PAvgPool2D((1,2), padding='valid')",
# "PAvgPool2D((2,2), padding='valid')",
# "PFlatten()",
"PAvgPool1D(2, padding='same')",
"PAvgPool2D((1,2), padding='same')",
"PAvgPool2D((2,2), padding='same')",
"PAvgPool1D(2, padding='valid')",
"PAvgPool2D((1,2), padding='valid')",
"PAvgPool2D((2,2), padding='valid')",
"PFlatten()",
]
)
@pytest.mark.parametrize("N", [1000, 10])
Expand All @@ -96,7 +102,14 @@ def test_syn_players(layer, N: int, rnd_strategy: str, io_type: str, cover_facto
q = gfixed(1, 6, 3)
data = q(data).numpy()

run_model_test(model, cover_factor, data, io_type, backend, dir, aggressive)
cond = None
if 'AvgPool' in layer and io_type == 'io_parallel':
if cover_factor < 1.0:
# pass
pytest.skip('AvgPool\'s accum is not configurable for io_parallel, and cover_factor < 1.0 leads to overflow cannot be emulated')
cond = parallel_avg_pool_cond
# mark as xfail if io_parallel and cover_factor < 1.0
run_model_test(model, cover_factor, data, io_type, backend, dir, aggressive, cond=cond)


if __name__ == '__main__':
Expand Down

0 comments on commit 008b489

Please sign in to comment.