Skip to content

Commit

Permalink
Fix restore plot
Browse files Browse the repository at this point in the history
  • Loading branch information
ppizarror committed Aug 13, 2024
1 parent 542ed75 commit 04de40e
Show file tree
Hide file tree
Showing 4 changed files with 18 additions and 5 deletions.
2 changes: 1 addition & 1 deletion MLStructFP/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -9,7 +9,7 @@
__description__ = 'Machine learning structural floor plan dataset'
__keywords__ = ['ml', 'ai', 'floor-plan', 'architectural', 'dataset', 'cnn']
__email__ = '[email protected]'
__version__ = '0.6.1'
__version__ = '0.6.2'

# URL
__url__ = 'https://github.com/MLSTRUCT/MLSTRUCT-FP'
Expand Down
5 changes: 5 additions & 0 deletions MLStructFP/db/image/_base.py
Original file line number Diff line number Diff line change
Expand Up @@ -28,6 +28,7 @@ class BaseImage(ABC):
"""
_image_size: int
_images: List['np.ndarray'] # List of stored images during make_region
_last_make_region_time: float # Total time for last make region
_names: List[str]
_path: str
_save_images: bool
Expand Down Expand Up @@ -98,6 +99,10 @@ def make_region(self, xmin: NumberType, xmax: NumberType, ymin: NumberType, ymax
"""
raise NotImplementedError()

@property
def make_region_last_time(self) -> float:
return self._last_make_region_time

def export(self, filename: str, close: bool = True, compressed: bool = True) -> None:
"""
Export saved images to numpy format and then removes all data.
Expand Down
11 changes: 8 additions & 3 deletions MLStructFP/db/image/_rect_binary.py
Original file line number Diff line number Diff line change
Expand Up @@ -17,6 +17,7 @@
import matplotlib.pyplot as plt
import numpy as np
import os
import time

if TYPE_CHECKING:
from MLStructFP.db._c_rect import Rect
Expand Down Expand Up @@ -74,7 +75,7 @@ def init(self) -> 'RectBinaryImage':
"""
plt.switch_backend('agg')
self._initialized = True
self.close()
self.close(restore_plot=False)
self._initialized = True
return self

Expand Down Expand Up @@ -168,6 +169,7 @@ def make_region(self, xmin: NumberType, xmax: NumberType, ymin: NumberType, ymax
"""
if not self._initialized:
raise RuntimeError('Exporter not initialized, use .init()')
t0 = time.time()
store_matplotlib_figure = not HIGHLIGHT_RECT
fig, ax = self._get_floor_plot(floor, rect, store=store_matplotlib_figure)

Expand Down Expand Up @@ -226,12 +228,15 @@ def make_region(self, xmin: NumberType, xmax: NumberType, ymin: NumberType, ymax
del im, im2, im3, im4, fig, ax

# Returns the image index on the library array
self._last_make_region_time = time.time() - t0
return len(self._images) - 1, array

def close(self) -> None:
def close(self, restore_plot: bool = True) -> None:
"""
Close and delete all generated figures.
This function also restores plot engine.
:param restore_plot: Restores plotting engine
"""
if not self._initialized:
raise RuntimeError('Exporter not initialized, it cannot be closed')
Expand All @@ -247,7 +252,7 @@ def close(self) -> None:
self._names.clear()

# Restore plot
if plt.get_backend() == 'agg':
if restore_plot and plt.get_backend() == 'agg':
plt.switch_backend(INITIAL_BACKEND)

self._initialized = False
5 changes: 4 additions & 1 deletion MLStructFP/db/image/_rect_photo.py
Original file line number Diff line number Diff line change
Expand Up @@ -381,9 +381,12 @@ def make_region(self, xmin: NumberType, xmax: NumberType, ymin: NumberType, ymax
:return: Returns the image index on the library array
"""
assert xmax > xmin and ymax > ymin
t0 = time.time()
dx = (xmax - xmin) / 2
dy = (ymax - ymin) / 2
return self._make(floor, GeomPoint2D(xmin + dx, ymin + dy), dx, dy, rect)
mk = self._make(floor, GeomPoint2D(xmin + dx, ymin + dy), dx, dy, rect)
self._last_make_region_time = time.time() - t0
return mk

def _make(self, floor: 'Floor', cr: 'GeomPoint2D', dx: float, dy: float, rect: Optional['Rect']
) -> Tuple[int, 'np.ndarray']:
Expand Down

0 comments on commit 04de40e

Please sign in to comment.