Skip to content

Commit

Permalink
WIP: Fix tempfile removal for windows
Browse files Browse the repository at this point in the history
  • Loading branch information
SorooshMani-NOAA committed Nov 16, 2023
1 parent 26287ef commit 8068f94
Show file tree
Hide file tree
Showing 8 changed files with 120 additions and 48 deletions.
15 changes: 12 additions & 3 deletions ocsmesh/hfun/raster.py
Original file line number Diff line number Diff line change
Expand Up @@ -4,8 +4,10 @@
import functools
import gc
import logging
import os
from multiprocessing import cpu_count, Pool
import operator
import pathlib
import tempfile
from time import time
from typing import Union, List, Callable, Optional, Iterable, Tuple
Expand Down Expand Up @@ -234,6 +236,11 @@ def __init__(self,
self._constraints = []


def __del__(self):
for _, memfile_path in self._xy_cache.items():
pathlib.Path(memfile_path).unlink()


def msh_t(
self,
window: Optional[rasterio.windows.Window] = None,
Expand Down Expand Up @@ -1296,15 +1303,17 @@ def get_xy_memcache(
transformer = Transformer.from_crs(
self.src.crs, dst_crs, always_xy=True)
# pylint: disable=R1732
tmpfile = tempfile.NamedTemporaryFile()
# tmpfile = tempfile.NamedTemporaryFile()
tmpfd, tmppath = tempfile.mkstemp()
xy = self.get_xy(window)
fp = np.memmap(tmpfile, dtype='float32', mode='w+', shape=xy.shape)
fp = np.memmap(tmppath, dtype='float32', mode='w+', shape=xy.shape)
os.close(tmpfd)
fp[:] = np.vstack(
transformer.transform(xy[:, 0], xy[:, 1])).T
_logger.info('Saving values to memcache...')
fp.flush()
_logger.info('Done!')
self._xy_cache[f'{window}{dst_crs}'] = tmpfile
self._xy_cache[f'{window}{dst_crs}'] = tmppath
return fp[:]

_logger.info('Loading values from memcache...')
Expand Down
8 changes: 5 additions & 3 deletions ocsmesh/ops/combine_geom.py
Original file line number Diff line number Diff line change
Expand Up @@ -163,16 +163,17 @@ def run(self):

poly_files_coll = []
_logger.info(f"Number of processes: {nprocs}")
with tempfile.TemporaryDirectory(dir=out_dir) as temp_dir, \
tempfile.NamedTemporaryFile() as base_file:
with tempfile.TemporaryDirectory(dir=out_dir) as temp_dir:

tmpfd, tmppath = tempfile.mkstemp()
if base_mult_poly:
base_mesh_path = base_file.name
base_mesh_path = tmppath
self._multipolygon_to_disk(
base_mesh_path, base_mult_poly, fix=False)
else:
base_mesh_path = None
base_mult_poly = None
os.close(tmpfd)


_logger.info("Processing DEM priorities ...")
Expand Down Expand Up @@ -235,6 +236,7 @@ def run(self):
],
ignore_index=True
)
pathlib.Path(tmppath).unlink()


# The assumption is this returns polygon or multipolygon
Expand Down
42 changes: 33 additions & 9 deletions ocsmesh/raster.py
Original file line number Diff line number Diff line change
Expand Up @@ -13,6 +13,7 @@
import pathlib
import tempfile
import warnings
import platform
from time import time
from contextlib import contextmanager, ExitStack
from typing import (
Expand Down Expand Up @@ -144,15 +145,23 @@ class TemporaryFile:
cleanup capabities on object destruction.
"""

def __set__(self, obj, val: tempfile.NamedTemporaryFile):
def __set__(self, obj, val: Optional[os.PathLike]):
tmpfile = obj.__dict__.get('tmpfile')
if tmpfile is not None:
obj._src = None
pathlib.Path(tmpfile).unlink()

obj.__dict__['tmpfile'] = val
obj._src = rasterio.open(val.name)
if val is None:
obj._src = None
else:
obj._src = rasterio.open(val)

def __get__(self, obj, objtype=None) -> pathlib.Path:
tmpfile = obj.__dict__.get('tmpfile')
if tmpfile is None:
return obj.path
return pathlib.Path(tmpfile.name)
return pathlib.Path(tmpfile)


class SourceRaster:
Expand All @@ -165,7 +174,10 @@ class SourceRaster:
opening it everytime need arises.
"""

def __set__(self, obj, val: rasterio.DatasetReader):
def __set__(self, obj, val: Optional[rasterio.DatasetReader]):
source = obj.__dict__.get('source')
if source is not None:
source.close()
obj.__dict__['source'] = val

def __get__(self, obj, objtype=None) -> rasterio.DatasetReader:
Expand Down Expand Up @@ -345,6 +357,9 @@ def __init__(
self._path = path
self._crs = crs

def __del__(self):
self._tmpfile = None

def __iter__(self, chunk_size: int = None, overlap: int = None):
for window in self.iter_windows(chunk_size, overlap):
yield window, self.get_window_bounds(window)
Expand Down Expand Up @@ -382,14 +397,15 @@ def modifying_raster(
no_except = False
try:
# pylint: disable=R1732
tmpfile = tempfile.NamedTemporaryFile(prefix=tmpdir)
# tmpfile = tempfile.NamedTemporaryFile(prefix=tmpdir, mode='w')
tmpfd, tmppath = tempfile.mkstemp(prefix=tmpdir)

new_meta = kwargs
# Flag to workaround cases where "src" is NOT set yet
if use_src_meta:
new_meta = self.src.meta.copy()
new_meta.update(**kwargs)
with rasterio.open(tmpfile.name, 'w', **new_meta) as dst:
with rasterio.open(tmppath, 'w', **new_meta) as dst:
if use_src_meta:
for i, desc in enumerate(self.src.descriptions):
dst.set_band_description(i+1, desc)
Expand All @@ -399,9 +415,12 @@ def modifying_raster(

finally:
if no_except:
# So that tmpfile is NOT destroyed when it locally
# goes out of scope
self._tmpfile = tmpfile
self._tmpfile = tmppath

# We don't need to keep the descriptor open, we kept it
# open # so that there's no race condition on the temp
# file up to now
os.close(tmpfd)



Expand Down Expand Up @@ -944,6 +963,8 @@ def average_filter(
# in other parts of the code. Thorough testing is needed for
# modifying the raster (e.g. hfun add_contour is affected)

if platform.system() == 'Windows':
raise ImplementationError('Not supported on Windows!')
bands = apply_on_bands
if bands is None:
bands = range(1, self.src.count + 1)
Expand Down Expand Up @@ -1002,6 +1023,9 @@ def generic_filter(self, function, **kwargs: Any) -> None:
None
"""

if platform.system() == 'Windows':
raise ImplementationError('Not supported on Windows!')

# TODO: Don't overwrite; add additoinal bands for filtered values

# NOTE: Adding new bands in this function can result in issues
Expand Down
13 changes: 7 additions & 6 deletions tests/api/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -14,9 +14,10 @@
)
TEST_FILE = os.path.join(tempfile.gettempdir(), 'test_dem.tif')
if not Path(TEST_FILE).exists():
with tempfile.NamedTemporaryFile() as tfp:
urllib.request.urlretrieve(tif_url, filename=tfp.name)
r = Raster(tfp.name)
r.resampling_method = Resampling.average
r.resample(scaling_factor=0.01)
r.save(TEST_FILE)
tmpfd, tmppath = tempfile.mkstemp()
urllib.request.urlretrieve(tif_url, filename=tmppath)
os.close(tmpfd)
r = Raster(tmppath)
r.resampling_method = Resampling.average
r.resample(scaling_factor=0.01)
r.save(TEST_FILE)
1 change: 0 additions & 1 deletion tests/api/hfun.py
Original file line number Diff line number Diff line change
Expand Up @@ -907,7 +907,6 @@ def setUp(self):
)
mesh = ocsmesh.Mesh(msh_t)
mesh.write(str(self.mesh1), format='grd', overwrite=False)
mesh.write('/tmp/ocsmesh/mytest2.2dm', format='2dm', overwrite=True)


def tearDown(self):
Expand Down
9 changes: 6 additions & 3 deletions tests/api/mesh.py
Original file line number Diff line number Diff line change
@@ -1,4 +1,5 @@
#! python
import os
import tempfile
import unittest
import warnings
Expand Down Expand Up @@ -279,9 +280,11 @@ def test_specified_boundary_order_withmerge(self):
def test_specify_boundary_on_imported_mesh_with_boundary(self):
self.mesh.boundaries.auto_generate()

with tempfile.NamedTemporaryFile(suffix='.grd') as fo:
self.mesh.write(fo.name, format='grd', overwrite=True)
imported_mesh = Mesh.open(fo.name)
tmpfd, tmppath = tempfile.mkstemp(suffix='.grd')
self.mesh.write(tmppath, format='grd', overwrite=True)
imported_mesh = Mesh.open(tmppath)
os.close(tmpfd)
os.unlink(tmppath)

bdry = imported_mesh.boundaries

Expand Down
39 changes: 28 additions & 11 deletions tests/api/raster.py
Original file line number Diff line number Diff line change
@@ -1,6 +1,7 @@
import shutil
import tempfile
import unittest
import platform
from pathlib import Path

import numpy as np
Expand All @@ -9,6 +10,10 @@
from ocsmesh.utils import raster_from_numpy


IS_WINDOWS = platform.system() == 'Windows'



class Raster(unittest.TestCase):
def setUp(self):
self.tdir = Path(tempfile.mkdtemp())
Expand Down Expand Up @@ -38,21 +43,33 @@ def tearDown(self):
shutil.rmtree(self.tdir)


@unittest.skipIf(IS_WINDOWS, 'Not supported due to LowLevelFunction int')
def test_avg_filter_nomask(self):
rast = ocsmesh.Raster(self.rast1)
rast.average_filter(size=17)
self.assertTrue(np.all(rast.get_values() == 10))
try:
rast = ocsmesh.Raster(self.rast1)
rast.average_filter(size=17)
self.assertTrue(np.all(rast.get_values() == 10))
finally:
del rast


@unittest.skipIf(IS_WINDOWS, 'Not supported due to LowLevelFunction int')
def test_avg_filter_masked_nanfill(self):
rast = ocsmesh.Raster(self.rast2)
rast.average_filter(size=17)
self.assertTrue(
np.all(rast.values[~np.isnan(rast.values)] == 10))
try:
rast = ocsmesh.Raster(self.rast2)
rast.average_filter(size=17)
self.assertTrue(
np.all(rast.values[~np.isnan(rast.values)] == 10))
finally:
del rast


@unittest.skipIf(IS_WINDOWS, 'Not supported due to LowLevelFunction int')
def test_avg_filter_masked_nonnanfill(self):
rast = ocsmesh.Raster(self.rast3)
rast.average_filter(size=17)
self.assertTrue(
np.all(rast.values[rast.values != rast.nodata] == 10))
try:
rast = ocsmesh.Raster(self.rast3)
rast.average_filter(size=17)
self.assertTrue(
np.all(rast.values[rast.values != rast.nodata] == 10))
finally:
del rast
Loading

0 comments on commit 8068f94

Please sign in to comment.