Skip to content

Commit

Permalink
Fix np.matrix complex types
Browse files Browse the repository at this point in the history
  • Loading branch information
WardBrian committed Jul 8, 2024
1 parent 799f296 commit 13e0f1b
Show file tree
Hide file tree
Showing 3 changed files with 11 additions and 2 deletions.
2 changes: 1 addition & 1 deletion stanio/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -11,4 +11,4 @@
"stan_variables",
]

__version__ = "0.5.0"
__version__ = "0.5.1"
2 changes: 1 addition & 1 deletion stanio/json.py
Original file line number Diff line number Diff line change
Expand Up @@ -46,7 +46,7 @@ def process_value(val: Any) -> Any:
if numpy_val.dtype.kind in "iuf":
return numpy_val.tolist()
if numpy_val.dtype.kind == "c":
return np.stack([numpy_val.real, numpy_val.imag], axis=-1).tolist()
return np.stack([np.asarray(numpy_val.real), np.asarray(numpy_val.imag)], axis=-1).tolist()
if numpy_val.dtype.kind == "b":
return numpy_val.astype(int).tolist()

Expand Down
9 changes: 9 additions & 0 deletions test/test_json.py
Original file line number Diff line number Diff line change
Expand Up @@ -161,6 +161,15 @@ def test_complex_numbers_np(TMPDIR) -> None:
compare_before_after(file_complex, dict_complex, dict_complex_exp)


def test_complex_numbers_np_matrix(TMPDIR) -> None:
a_raw = np.random.rand(21, 21, 2)
a = np.matrix(a_raw[:, :, 0] + 1j * a_raw[:, :, 1])
dict_complex = {"a": a}
dict_complex_exp = {"a": a_raw}
file_complex = os.path.join(TMPDIR, "complex_np_mat.json")
compare_before_after(file_complex, dict_complex, dict_complex_exp)


def test_tuples(TMPDIR) -> None:
dict_tuples = {
"a": (1, 2, 3),
Expand Down

0 comments on commit 13e0f1b

Please sign in to comment.