Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Image2equi7grid parallel processing #46

Open
wants to merge 6 commits into
base: master
Choose a base branch
from
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
1 change: 1 addition & 0 deletions .gitignore
Original file line number Diff line number Diff line change
Expand Up @@ -18,6 +18,7 @@ __pycache__/*
.pydevproject
.settings
.idea
.vscode/

# Package files
*.egg
Expand Down
4 changes: 2 additions & 2 deletions src/equi7grid/_version.py
Original file line number Diff line number Diff line change
@@ -1,2 +1,2 @@
__version__ = "0.2.4.dev0+g2f2e096"
__commit__ = "2f2e096"
__version__ = "0.2.4.dev9+gdd4771b"
__commit__ = "dd4771b"
284 changes: 283 additions & 1 deletion src/equi7grid/image2equi7grid.py
Original file line number Diff line number Diff line change
Expand Up @@ -35,6 +35,7 @@
spatial reference, using the Equi7TilingSystem() for the file tiling.
"""

from functools import partial
import os
import subprocess
from datetime import datetime
Expand All @@ -43,6 +44,8 @@
from osgeo import gdal
from osgeo import osr
from osgeo import ogr
import multiprocessing as mp
from equi7grid.equi7grid import Equi7Grid

# dict for transfer the datatype and resample type
gdal_datatype = {
Expand Down Expand Up @@ -453,6 +456,285 @@ def image2equi7grid(e7grid,
return dst_file_names


class TileResampler(object):

def __init__(self,
e7grid,
image,
output_dir,
gdal_path=None,
inband=None,
subgrid_ids=None,
accurate_boundary=True,
e7_folder=True,
ftiles=None,
coverland=True,
roi=None,
naming_convention=None,
compress_type="LZW",
resampling_type="bilinear",
subfolder=None,
overwrite=False,
data_type=None,
image_nodata=None,
tile_nodata=None,
scale=None,
offset=None,
tiled=True,
blocksize=512,
parallelism=1):

self.e7grid = e7grid
self.outdir = output_dir
self.subgrid_ids = subgrid_ids
self.image = image
self.roi = roi
self.coverland = coverland
self.inband = inband
self.scale = scale
self.offset = offset
self.ftiles = ftiles
self.image_nodata = image_nodata
self.tile_nodata = tile_nodata if tile_nodata else image_nodata
self.naming_convention = naming_convention
self.resampling_type=resampling_type
self.compress_type=compress_type
self.data_type = data_type
self.overwrite = overwrite
self.tiled = tiled
self.blocksize=blocksize
self.e7_folder = e7_folder
self.subfolder = subfolder
self.accurate_boundary = accurate_boundary
self.parallelism = parallelism

# get gdal_dir for compatibility
self.gdal_path = gdal_path

if self.ftiles is None:
self.ftiles = self.find_overlapping_tiles()
else:
if type(ftiles) != list:
self.ftiles = [ftiles]


def set_image(self, image, image_nodata=None):
self.image = image
self.image_nodata = image_nodata

def find_overlapping_tiles(self):
if self.roi is None:
if self.accurate_boundary:
try:
geo_extent = retrieve_raster_boundary(
self.image,
gdal_path=self.gdal_path,
nodata=self.image_nodata)
except Exception as e:
print("retrieve_raster_boundary failed:", str(e))
geo_extent = None
else:
geo_extent = None
if geo_extent:
ftiles = self.e7grid.search_tiles_in_roi(
roi_geometry=geo_extent,
subgrid_ids=self.subgrid_ids,
coverland=self.coverland)
else:
ds = open_image(self.image)
img_extent = ds.get_extent()
bbox = (img_extent[0:2], img_extent[2:4])
img_spref = osr.SpatialReference()
img_spref.ImportFromWkt(ds.projection())
ftiles = self.e7grid.search_tiles_in_roi(
bbox=bbox,
subgrid_ids=self.subgrid_ids,
osr_spref=img_spref,
coverland=self.coverland)
else:
ftiles = self.e7grid.search_tiles_in_roi(
roi_geometry=self.roi,
subgrid_ids=self.subgrid_ids,
coverland=self.coverland)

return ftiles


def prepare_resampling_args(self):
if self.naming_convention is not None:
self.naming_convention["grid_name"] = "{grid_name}"
self.naming_convention["tile_name"] = "{tile_name}"
file_naming = str(self.naming_convention)
else:
file_naming = None


args = {"image": self.image,
"output_dir": self.outdir,
"inband": self.inband,
"gdal_path": self.gdal_path,
"image_nodata": self.image_nodata,
"tile_nodata": self.tile_nodata,
"resampling_type": self.resampling_type,
"compress_type" : self.compress_type,
"naming_convention": file_naming,
"data_type": self.data_type,
"e7_folder": self.e7_folder,
"subfolder": self.subfolder,
"overwrite": self.overwrite,
"tiled": self.tiled,
"blocksize": self.blocksize,
"scale": self.scale,
"offset": self.offset,
}

return args


def resample_tiles(self):
args = self.prepare_resampling_args()

dst_file_names = []

if self.parallelism == 1:
for ftile in self.ftiles:
result = resample_tile(ftile, **args)
if result:
dst_file_names.append(result)
else:
n_tiles = len(self.ftiles)

num_cpu = self.parallelism if n_tiles > self.parallelism else n_tiles

with mp.Pool(processes=num_cpu) as pool:
func = partial(resample_tile, **args)
results = pool.map(func, self.ftiles)
pool.close()
pool.join()

dst_file_names = [x for x in results if x]

return dst_file_names


def resample_tile(ftile,
image=None,
output_dir=None,
inband=None,
gdal_path=None,
image_nodata=None,
tile_nodata=None,
resampling_type="biliear",
compress_type="LZW",
naming_convention=None,
data_type=None,
e7_folder=None,
subfolder=None,
overwrite=False,
tiled=False,
blocksize=None,
scale=None,
offset=None):

sampling = int(ftile[2:5])
e7grid = Equi7Grid(sampling)

# create grid folder
if e7_folder:
grid_folder = "EQUI7_{}".format(ftile[0:6])
tile_path = os.path.join(output_dir, grid_folder, ftile[7:])
if not os.path.exists(tile_path):
os.makedirs(tile_path)
else:
tile_path = output_dir

# make output filename
if naming_convention is None:
out_filename = os.path.splitext(os.path.basename(image))[0]
out_filename = "_".join((out_filename, ftile + ".tif"))
else:
try:
out_filename = naming_convention.format(
grid_name=ftile.split('_')[0], tile_name=ftile.split('_')[1])
except KeyError:
err_msg = "File naming convention does not contain 'grid_name' or 'tile_name'."
raise KeyError(err_msg)
if subfolder:
out_filepath = os.path.join(tile_path, subfolder, out_filename)
else:
out_filepath = os.path.join(tile_path, out_filename)

# using gdalwarp to resample
bbox = e7grid.get_tile_bbox_proj(ftile)
tile_project = '"{}"'.format(
e7grid.subgrids[ftile[0:2]].core.projection.proj4)

# prepare options for gdalwarp
options = {
'-t_srs': tile_project,
'-of': 'GTiff',
'-r': resampling_type,
'-te': " ".join(map(str, bbox)),
'-tr': "{} -{}".format(e7grid.core.sampling, e7grid.core.sampling)
}

options["-co"] = list()
if compress_type is not None:
options["-co"].append("COMPRESS={0}".format(compress_type))
if image_nodata != None:
options["-srcnodata"] = image_nodata
if tile_nodata != None:
options["-dstnodata"] = tile_nodata
if data_type != None:
options["-ot"] = data_type
options["-wt"] = data_type # test if this is what we want
if overwrite:
options["-overwrite"] = " "
if tiled: # tiled, square blocks
options["-co"].append("TILED=YES")
options["-co"].append("BLOCKXSIZE={0}".format(blocksize))
options["-co"].append("BLOCKYSIZE={0}".format(blocksize))
else: # stripped blocks
blockxsize = e7grid.core.tile_xsize_m // e7grid.core.sampling
blockysize = blocksize
options["-co"].append("TILED=NO")
options["-co"].append("BLOCKXSIZE={0}".format(blockxsize))
options["-co"].append("BLOCKYSIZE={0}".format(blockysize))

# call gdalwarp for resampling
succeed, _ = call_gdal_util('gdalwarp',
src_files=image,
src_band=inband,
dst_file=out_filepath,
gdal_path=gdal_path,
options=options)


if scale is not None and offset is not None:
# prepare options for gdal_translate
options = {'-a_scale': scale, '-a_offset': offset}
options["-co"] = list()
if tile_nodata != None:
options["-a_nodata "] = tile_nodata
if compress_type is not None:
options["-co"].append("COMPRESS={0}".format(compress_type))
if blocksize is not None:
options["-co"].append("TILED=YES")
options["-co"].append("BLOCKXSIZE={0}".format(blocksize))
options["-co"].append("BLOCKYSIZE={0}".format(blocksize))

succeed, _ = call_gdal_util('gdal_translate',
src_files=out_filepath,
dst_file=out_filepath,
gdal_path=gdal_path,
options=options)

if succeed:
return out_filepath
else:
return None


def open_image(filename):
""" open an image file

Expand Down Expand Up @@ -673,7 +955,7 @@ def retrieve_raster_boundary(infile,
# morphologic dilation
pixels = 3
struct = ndimage.generate_binary_structure(2, 2)
struct = ndimage.morphology.iterate_structure(struct, pixels)
struct = ndimage.iterate_structure(struct, pixels)
new_mask = ndimage.binary_dilation(mask, structure=struct)
src_arr = np.zeros_like(mask, dtype=np.uint8)
src_arr[new_mask] = 1
Expand Down
30 changes: 29 additions & 1 deletion tests/test_approve_image2equi7grid.py
Original file line number Diff line number Diff line change
Expand Up @@ -5,7 +5,7 @@
from approvaltests.namer import NamerFactory

from equi7grid.equi7grid import Equi7Grid
from equi7grid.image2equi7grid import image2equi7grid
from equi7grid.image2equi7grid import TileResampler, image2equi7grid


@pytest.fixture
Expand All @@ -27,3 +27,31 @@ def test_approve_imag2equi7grid(input_dir, out_dir):
options=NamerFactory.with_parameters("E018N066T6"))
verify_file((out_dir / "EQUI7_EU100M/E072N030T6/lake_in_russia_lonlat_EU100M_E072N030T6.tif").as_posix(),
options=NamerFactory.with_parameters("E072N030T6"))


@pytest.mark.skipif(os.name == 'nt', reason="CI Windows has troubles creating directories")
def test_multiprocessing_10m(input_dir, out_dir):
# begin-snippet: image2equi7grid-example
input_file = input_dir / "lake_in_russia_lonlat.tif"

tile_resampler = TileResampler(Equi7Grid(10), input_file.as_posix(), out_dir.as_posix(), parallelism=12)
results = tile_resampler.resample_tiles()

print(results)

assert (out_dir / "EQUI7_AS010M/E021N069T1/lake_in_russia_lonlat_AS010M_E021N069T1.tif").exists()
assert (out_dir / "EQUI7_EU010M/E073N032T1/lake_in_russia_lonlat_EU010M_E073N032T1.tif").exists()


@pytest.mark.skipif(os.name == 'nt', reason="CI Windows has troubles creating directories")
def test_multiprocessing_20m(input_dir, out_dir):
# begin-snippet: image2equi7grid-example
input_file = input_dir / "lake_in_russia_lonlat.tif"

tile_resampler = TileResampler(Equi7Grid(20), input_file.as_posix(), out_dir.as_posix(), parallelism=12)
results = tile_resampler.resample_tiles()

print(results)

assert (out_dir / "EQUI7_EU020M/E072N030T3/lake_in_russia_lonlat_EU020M_E072N030T3.tif").exists()
assert (out_dir / "EQUI7_AS020M/E021N069T3/lake_in_russia_lonlat_AS020M_E021N069T3.tif").exists()