Skip to content

Commit

Permalink
add checks on copying arrays
Browse files Browse the repository at this point in the history
  • Loading branch information
Egor Baturin committed Oct 28, 2024
1 parent 4f744fd commit d98251a
Show file tree
Hide file tree
Showing 2 changed files with 11 additions and 0 deletions.
6 changes: 6 additions & 0 deletions tests/test_models/test_nn/deepar_native/test_deepar_native.py
Original file line number Diff line number Diff line change
Expand Up @@ -137,6 +137,12 @@ def test_deepar_make_samples(df_name, scale, weights, cat_columns, request):
assert ts_samples[i]["segment"] == "segment_1"
for key in expected_sample:
np.testing.assert_equal(ts_samples[i][key], expected_sample[key])
if "categorical" in key:
for column in ts_samples[i][key]:
assert ts_samples[i][key][column].base is not None
else:
if key != "weight":
assert ts_samples[i][key].base is not None


@pytest.mark.parametrize("encoder_length", [1, 2, 10])
Expand Down
5 changes: 5 additions & 0 deletions tests/test_models/test_nn/test_deepstate.py
Original file line number Diff line number Diff line change
Expand Up @@ -139,6 +139,11 @@ def test_deepstate_make_samples(df_name, cat_columns, request):
assert ts_samples[i]["segment"] == "segment_1"
for key in expected_sample:
np.testing.assert_equal(ts_samples[i][key], expected_sample[key])
if "categorical" in key:
for column in ts_samples[i][key]:
assert ts_samples[i][key][column].base is not None
else:
assert ts_samples[i][key].base is not None


def test_save_load(example_tsds):
Expand Down

0 comments on commit d98251a

Please sign in to comment.