Skip to content

Commit

Permalink
Add support for patches storage in baseimage
Browse files Browse the repository at this point in the history
  • Loading branch information
ppizarror committed Aug 13, 2024
1 parent 382bfa7 commit 542ed75
Show file tree
Hide file tree
Showing 5 changed files with 33 additions and 27 deletions.
2 changes: 1 addition & 1 deletion MLStructFP/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -9,7 +9,7 @@
__description__ = 'Machine learning structural floor plan dataset'
__keywords__ = ['ml', 'ai', 'floor-plan', 'architectural', 'dataset', 'cnn']
__email__ = '[email protected]'
__version__ = '0.6.0'
__version__ = '0.6.1'

# URL
__url__ = 'https://github.com/MLSTRUCT/MLSTRUCT-FP'
Expand Down
25 changes: 17 additions & 8 deletions MLStructFP/db/image/_base.py
Original file line number Diff line number Diff line change
Expand Up @@ -13,22 +13,26 @@
import numpy as np
import os

from abc import ABC, abstractmethod

if TYPE_CHECKING:
from MLStructFP.db._c_rect import Rect
from MLStructFP.db._floor import Floor

TYPE_IMAGE: str = 'uint8'


class BaseImage(object):
class BaseImage(ABC):
"""
Base dataset image object.
"""
_image_size: int
_images: List['np.ndarray']
_images: List['np.ndarray'] # List of stored images during make_region
_names: List[str]
_path: str
_save_images: bool

patches: List['np.ndarray'] # Additional stored images
save: bool

def __init__(self, path: str, save_images: bool, image_size_px: int) -> None:
Expand All @@ -46,18 +50,28 @@ def __init__(self, path: str, save_images: bool, image_size_px: int) -> None:
make_dirs(path)
assert os.path.isdir(path), f'Path <{path}> does not exist'

super(ABC, self).__init__()
self._image_size = image_size_px
self._images = []
self._names = [] # List of image names
self._path = path
self._save_images = save_images # Caution, this can be file expensive

self.patches = []
self.save = True

@property
def image_shape(self) -> Tuple[int, int]:
return self._image_size, self._image_size

@abstractmethod
def close(self) -> None:
"""
Close and delete all generated figures.
"""
raise NotImplementedError()

@abstractmethod
def make_rect(self, rect: 'Rect', crop_length: NumberType) -> Tuple[int, 'np.ndarray']:
"""
Generate image for the perimeter of a given rectangle.
Expand All @@ -68,6 +82,7 @@ def make_rect(self, rect: 'Rect', crop_length: NumberType) -> Tuple[int, 'np.nda
"""
raise NotImplementedError()

@abstractmethod
def make_region(self, xmin: NumberType, xmax: NumberType, ymin: NumberType, ymax: NumberType,
floor: 'Floor', rect: Optional['Rect'] = None) -> Tuple[int, 'np.ndarray']:
"""
Expand Down Expand Up @@ -106,12 +121,6 @@ def export(self, filename: str, close: bool = True, compressed: bool = True) ->
if close:
self.close()

def close(self) -> None:
"""
Close and delete all generated figures.
"""
raise NotImplementedError()

def get_images(self) -> 'np.ndarray':
"""
:return: Images as numpy ndarray
Expand Down
11 changes: 4 additions & 7 deletions MLStructFP/db/image/_rect_binary.py
Original file line number Diff line number Diff line change
Expand Up @@ -231,6 +231,7 @@ def make_region(self, xmin: NumberType, xmax: NumberType, ymin: NumberType, ymax
def close(self) -> None:
"""
Close and delete all generated figures.
This function also restores plot engine.
"""
if not self._initialized:
raise RuntimeError('Exporter not initialized, it cannot be closed')
Expand All @@ -245,12 +246,8 @@ def close(self) -> None:
self._images.clear()
self._names.clear()

self._initialized = False

@staticmethod
def restore_plot() -> None:
"""
Restore plot backend.
"""
# Restore plot
if plt.get_backend() == 'agg':
plt.switch_backend(INITIAL_BACKEND)

self._initialized = False
21 changes: 11 additions & 10 deletions example.ipynb
Original file line number Diff line number Diff line change
Expand Up @@ -3110,8 +3110,8 @@
"# Apply mutations (default 0, 1, 1)\n",
"plot_floor.mutate(30, 1, 1)\n",
"\n",
"crop_length = 10 # Meters to fix the range\n",
"image_size_px = 256 # Which pixel target to resize\n",
"crop_length = 10 # Meters to fix the range\n",
"image_size_px = 256 # Which pixel target to resize\n",
"\n",
"plot_floor.plot_complex()"
]
Expand All @@ -3130,32 +3130,33 @@
"path_photo = f'.out/example_photo_{image_size_px}_{plot_floor.id}/'\n",
"export_binary = f'.out/image_binary_{plot_floor.id}'\n",
"\n",
"image_binary = RectBinaryImage(path=path_binary, save_images=True, image_size_px = image_size_px).init()\n",
"image_photo = RectFloorPhoto(path=path_photo, save_images=True, image_size_px = image_size_px)\n",
"image_binary = RectBinaryImage(path=path_binary, save_images=True, image_size_px=image_size_px).init()\n",
"image_photo = RectFloorPhoto(path=path_photo, save_images=True, image_size_px=image_size_px)\n",
"\n",
"for r in plot_floor.rect: image_binary.make_rect(r, crop_length=crop_length)\n",
"for r in plot_floor.rect: image_photo.make_rect(r, crop_length=crop_length)\n",
"\n",
"image_binary.export(export_binary)\n",
"image_photo.export(f'.out/image_photo_{plot_floor.id}')\n",
"image_binary.restore_plot()\n",
"image_binary.close()\n",
"\n",
"\n",
"# Compare images from binary/photo\n",
"def compare_images(rectfile: int):\n",
" flo = open(f'{export_binary}_{image_size_px}_files.csv', 'r')\n",
" flnm = ''\n",
" j = 0\n",
" for i in flo:\n",
" if j-1 == rectfile:\n",
" if j - 1 == rectfile:\n",
" flnm = i.split(',')[1].strip() + '.png'\n",
" break\n",
" j+=1\n",
" j += 1\n",
" assert flnm != ''\n",
"\n",
" fig = plt.figure(dpi=150)\n",
" plt.figure(dpi=150)\n",
" print(f'Loading picture ID: {flnm}')\n",
" ax = plt.subplot(121), plt.imshow(mpimg.imread(f'{path_binary}/{flnm}'))\n",
" ax = plt.subplot(122), plt.imshow(mpimg.imread(f'{path_photo}/{flnm}'))\n",
" plt.subplot(121), plt.imshow(mpimg.imread(f'{path_binary}/{flnm}'))\n",
" plt.subplot(122), plt.imshow(mpimg.imread(f'{path_photo}/{flnm}'))\n",
" plt.grid(False)\n",
" plt.show()"
]
Expand Down
1 change: 0 additions & 1 deletion test/test_db.py
Original file line number Diff line number Diff line change
Expand Up @@ -132,4 +132,3 @@ def test_image(self) -> None:
# Now exporters must be closed
self.assertEqual(len(image_binary.get_images()), 0)
self.assertEqual(len(image_photo.get_images()), 0)
image_binary.restore_plot()

0 comments on commit 542ed75

Please sign in to comment.