diff --git a/test/test_syn_players.py b/test/test_syn_players.py index ee39f3c..95f1c28 100644 --- a/test/test_syn_players.py +++ b/test/test_syn_players.py @@ -62,21 +62,27 @@ def get_data(N: int, sigma: float, max_scale: float, seed): return (a1 * a2).astype(np.float32) +def parallel_avg_pool_cond(a, b): + close: np.ndarray = np.abs(a - b) < 1e-2 + + assert np.all(close), f"Keras-Proxy mismatch for approx avg pool: {np.sum(np.any(~close, axis=tuple(range(1,close.ndim))))} out of {a.shape[0]} samples are very different. Sample: {a[~close].ravel()[:5]} vs {b[~close].ravel()[:5]}" + + @pytest.mark.parametrize('layer', [ "PConcatenate()", - # "PMaxPool1D(2, padding='same')", - # "PMaxPool2D((2,2), padding='same')", - # "PMaxPool1D(2, padding='valid')", - # "PMaxPool2D((2,2), padding='valid')", + "PMaxPool1D(2, padding='same')", + "PMaxPool2D((2,2), padding='same')", + "PMaxPool1D(2, padding='valid')", + "PMaxPool2D((2,2), padding='valid')", "Signature(1,6,3)" - # "PAvgPool1D(2, padding='same')", - # "PAvgPool2D((1,2), padding='same')", - # "PAvgPool2D((2,2), padding='same')", - # "PAvgPool1D(2, padding='valid')", - # "PAvgPool2D((1,2), padding='valid')", - # "PAvgPool2D((2,2), padding='valid')", - # "PFlatten()", + "PAvgPool1D(2, padding='same')", + "PAvgPool2D((1,2), padding='same')", + "PAvgPool2D((2,2), padding='same')", + "PAvgPool1D(2, padding='valid')", + "PAvgPool2D((1,2), padding='valid')", + "PAvgPool2D((2,2), padding='valid')", + "PFlatten()", ] ) @pytest.mark.parametrize("N", [1000, 10]) @@ -96,7 +102,14 @@ def test_syn_players(layer, N: int, rnd_strategy: str, io_type: str, cover_facto q = gfixed(1, 6, 3) data = q(data).numpy() - run_model_test(model, cover_factor, data, io_type, backend, dir, aggressive) + cond = None + if 'AvgPool' in layer and io_type == 'io_parallel': + if cover_factor < 1.0: + # pass + pytest.skip('AvgPool\'s accum is not configurable for io_parallel, and cover_factor < 1.0 leads to overflow cannot be emulated') + cond = parallel_avg_pool_cond + # mark as xfail if io_parallel and cover_factor < 1.0 + run_model_test(model, cover_factor, data, io_type, backend, dir, aggressive, cond=cond) if __name__ == '__main__':