diff --git a/geedim/utils.py b/geedim/utils.py index afd51f4..3b0b150 100644 --- a/geedim/utils.py +++ b/geedim/utils.py @@ -1,17 +1,17 @@ """ - Copyright 2021 Dugal Harris - dugalh@gmail.com +Copyright 2021 Dugal Harris - dugalh@gmail.com - Licensed under the Apache License, Version 2.0 (the "License"); - you may not use this file except in compliance with the License. - You may obtain a copy of the License at +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at - http://www.apache.org/licenses/LICENSE-2.0 + http://www.apache.org/licenses/LICENSE-2.0 - Unless required by applicable law or agreed to in writing, software - distributed under the License is distributed on an "AS IS" BASIS, - WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. - See the License for the specific language governing permissions and - limitations under the License. +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. """ from __future__ import annotations @@ -39,7 +39,7 @@ from geedim.enums import ResamplingMethod -if '__file__' in globals(): +if "__file__" in globals(): root_path = pathlib.Path(__file__).absolute().parents[1] else: root_path = pathlib.Path(os.getcwd()) @@ -47,7 +47,9 @@ _GDAL_AT_LEAST_35 = GDALVersion.runtime().at_least("3.5") -def Initialize(opt_url: Optional[str] = 'https://earthengine-highvolume.googleapis.com', **kwargs): +def Initialize( + opt_url: Optional[str] = "https://earthengine-highvolume.googleapis.com", **kwargs +): """ Initialise Earth Engine. @@ -69,13 +71,17 @@ def Initialize(opt_url: Optional[str] = 'https://earthengine-highvolume.googleap if not ee.data._credentials: # Adpated from https://gis.stackexchange.com/questions/380664/how-to-de-authenticate-from-earth-engine-api. - env_key = 'EE_SERVICE_ACC_PRIVATE_KEY' + env_key = "EE_SERVICE_ACC_PRIVATE_KEY" if env_key in os.environ: # authenticate with service account key_dict = json.loads(os.environ[env_key]) - credentials = ee.ServiceAccountCredentials(key_dict['client_email'], key_data=key_dict['private_key']) - ee.Initialize(credentials, opt_url=opt_url, **kwargs) + credentials = ee.ServiceAccountCredentials( + key_dict["client_email"], key_data=key_dict["private_key"] + ) + ee.Initialize( + credentials, opt_url=opt_url, project=key_dict["project_id"], **kwargs + ) else: ee.Initialize(opt_url=opt_url, **kwargs) @@ -121,11 +127,11 @@ def suppress_rio_logs(level: int = logging.ERROR): try: # GEE sets GeoTIFF `colorinterp` tags incorrectly. This suppresses `rasterio` warning relating to this: # 'Sum of Photometric type-related color channels and ExtraSamples doesn't match SamplesPerPixel' - rio_level = logging.getLogger('rasterio').getEffectiveLevel() - logging.getLogger('rasterio').setLevel(level) + rio_level = logging.getLogger("rasterio").getEffectiveLevel() + logging.getLogger("rasterio").setLevel(level) yield finally: - logging.getLogger('rasterio').setLevel(rio_level) + logging.getLogger("rasterio").setLevel(rio_level) def get_bounds(filename: pathlib.Path, expand: float = 5): @@ -150,7 +156,10 @@ def get_bounds(filename: pathlib.Path, expand: float = 5): expand_x = (bbox.right - bbox.left) * expand / 100.0 expand_y = (bbox.top - bbox.bottom) * expand / 100.0 bbox_expand = rio.coords.BoundingBox( - bbox.left - expand_x, bbox.bottom - expand_y, bbox.right + expand_x, bbox.top + expand_y + bbox.left - expand_x, + bbox.bottom - expand_y, + bbox.right + expand_x, + bbox.top + expand_y, ) else: bbox_expand = bbox @@ -164,7 +173,9 @@ def get_bounds(filename: pathlib.Path, expand: float = 5): ] bbox_expand_dict = dict(type="Polygon", coordinates=[coordinates]) - src_bbox_wgs84 = warp.transform_geom(im.crs, "WGS84", bbox_expand_dict) # convert to WGS84 geojson + src_bbox_wgs84 = warp.transform_geom( + im.crs, "WGS84", bbox_expand_dict + ) # convert to WGS84 geojson return src_bbox_wgs84 @@ -187,10 +198,12 @@ def get_projection(image: ee.Image, min_scale: bool = True) -> ee.Projection: Requested projection. """ if not isinstance(image, ee.Image): - raise TypeError('image is not an instance of ee.Image') + raise TypeError("image is not an instance of ee.Image") bands = image.bandNames() - scales = bands.map(lambda band: image.select(ee.String(band)).projection().nominalScale()) + scales = bands.map( + lambda band: image.select(ee.String(band)).projection().nominalScale() + ) projs = bands.map(lambda band: image.select(ee.String(band)).projection()) projs = projs.sort(scales) @@ -198,8 +211,7 @@ def get_projection(image: ee.Image, min_scale: bool = True) -> ee.Projection: class Spinner(Thread): - - def __init__(self, label='', interval=0.2, leave=True, **kwargs): + def __init__(self, label="", interval=0.2, leave=True, **kwargs): """ Thread sub-class to run a non-blocking spinner. @@ -234,20 +246,22 @@ def __exit__(self, exc_type, exc_val, exc_tb): def run(self): """Run the spinner thread.""" - cursors_it = itertools.cycle('/-\|') + cursors_it = itertools.cycle("/-\|") while self._run: cursor = next(cursors_it) - tqdm.write('\r' + self._label + cursor, file=self._file, end='') + tqdm.write("\r" + self._label + cursor, file=self._file, end="") self._file.flush() time.sleep(self._interval) if self._leave == True: - tqdm.write('', file=self._file, end='\n') + tqdm.write("", file=self._file, end="\n") elif self._leave == False: - tqdm.write('\r', file=self._file, end='') + tqdm.write("\r", file=self._file, end="") elif isinstance(self._leave, str): - tqdm.write('\r' + self._label + self._leave + ' ', file=self._file, end='\n') + tqdm.write( + "\r" + self._label + self._leave + " ", file=self._file, end="\n" + ) self._file.flush() def start(self): @@ -289,7 +303,12 @@ def resample(ee_image: ee.Image, method: ResamplingMethod) -> ee.Image: # resample the image, if it has a fixed projection proj = get_projection(ee_image, min_scale=True) - has_fixed_proj = proj.crs().compareTo('EPSG:4326').neq(0).Or(proj.nominalScale().toInt64().neq(111319)) + has_fixed_proj = ( + proj.crs() + .compareTo("EPSG:4326") + .neq(0) + .Or(proj.nominalScale().toInt64().neq(111319)) + ) def _resample(ee_image: ee.Image) -> ee.Image: """Resample the given image, allowing for additional 'average' method.""" @@ -313,15 +332,21 @@ def retry_session( """requests session configured for retries.""" session = session or requests.Session() retry = Retry( - total=retries, read=retries, connect=retries, backoff_factor=backoff_factor, status_forcelist=status_forcelist + total=retries, + read=retries, + connect=retries, + backoff_factor=backoff_factor, + status_forcelist=status_forcelist, ) adapter = HTTPAdapter(max_retries=retry) - session.mount('http://', adapter) - session.mount('https://', adapter) + session.mount("http://", adapter) + session.mount("https://", adapter) return session -def expand_window_to_grid(win: Window, expand_pixels: Tuple[int, int] = (0, 0)) -> Window: +def expand_window_to_grid( + win: Window, expand_pixels: Tuple[int, int] = (0, 0) +) -> Window: """ Expand rasterio window extents to the nearest whole numbers i.e. for ``expand_pixels`` >= (0, 0), it will return a window that contains the original extents. @@ -348,7 +373,7 @@ def expand_window_to_grid(win: Window, expand_pixels: Tuple[int, int] = (0, 0)) def rio_crs(crs: str | rio.CRS) -> str | rio.CRS: """Convert a GEE CRS string to a rasterio compatible CRS string.""" - if crs == 'SR-ORG:6974': + if crs == "SR-ORG:6974": # This is a workaround for https://issuetracker.google.com/issues/194561313, that replaces the alleged GEE # SR-ORG:6974 with actual WKT for SR-ORG:6842 taken from # https://github.com/OSGeo/spatialreference.org/blob/master/scripts/sr-org.json. @@ -378,8 +403,8 @@ def asset_id(image_id: str, folder: str = None): """ if not folder: return image_id - im_name = image_id.replace('/', '-') + im_name = image_id.replace("/", "-") folder = pathlib.PurePosixPath(folder) cloud_folder = pathlib.PurePosixPath(folder.parts[0]) - asset_path = pathlib.PurePosixPath('/'.join(folder.parts[1:])).joinpath(im_name) - return f'projects/{str(cloud_folder)}/assets/{str(asset_path)}' + asset_path = pathlib.PurePosixPath("/".join(folder.parts[1:])).joinpath(im_name) + return f"projects/{str(cloud_folder)}/assets/{str(asset_path)}"