Skip to content

Commit

Permalink
Merge pull request #550 from larrybradley/fix-numpy2.1-copy
Browse files Browse the repository at this point in the history
Update __array__ to copy if needed
  • Loading branch information
larrybradley authored May 2, 2024
2 parents 5f23252 + e6fb68d commit 15393a9
Show file tree
Hide file tree
Showing 2 changed files with 5 additions and 8 deletions.
7 changes: 5 additions & 2 deletions regions/core/mask.py
Original file line number Diff line number Diff line change
Expand Up @@ -6,9 +6,12 @@

import astropy.units as u
import numpy as np
from astropy.utils import minversion

__all__ = ['RegionMask']

COPY_IF_NEEDED = False if not minversion(np, '2.0.0.dev') else None


class RegionMask:
"""
Expand Down Expand Up @@ -40,12 +43,12 @@ def __init__(self, data, bbox):
self.bbox = bbox
self._mask = (self.data == 0)

def __array__(self, dtype=None, copy=None):
def __array__(self, dtype=None, copy=COPY_IF_NEEDED):
"""
Array representation of the mask data array (e.g., for
matplotlib).
"""
return np.asarray(self.data, dtype=dtype)
return np.array(self.data, dtype=dtype, copy=copy)

@property
def shape(self):
Expand Down
6 changes: 0 additions & 6 deletions regions/core/tests/test_mask.py
Original file line number Diff line number Diff line change
Expand Up @@ -54,12 +54,6 @@ def test_mask_copy():
mask_copy[0, 0] = 100.0
assert mask.data[0, 0] == 100.0

# needs to copy because of the dtype change
mask = RegionMask(np.ones((10, 10)), bbox)
mask_copy = np.array(mask, copy=False, dtype=int)
mask_copy[0, 0] = 100
assert mask.data[0, 0] == 1.0

# no copy
mask = RegionMask(np.ones((10, 10)), bbox)
mask_copy = np.asarray(mask)
Expand Down

0 comments on commit 15393a9

Please sign in to comment.