Skip to content

Commit

Permalink
add project kwarg to ee.Initialize()when called from github action
Browse files Browse the repository at this point in the history
  • Loading branch information
dugalh committed Feb 3, 2025
1 parent 6d8507e commit a5054f5
Showing 1 changed file with 63 additions and 38 deletions.
101 changes: 63 additions & 38 deletions geedim/utils.py
Original file line number Diff line number Diff line change
@@ -1,17 +1,17 @@
"""
Copyright 2021 Dugal Harris - [email protected]
Copyright 2021 Dugal Harris - [email protected]
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at
http://www.apache.org/licenses/LICENSE-2.0
http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
"""

from __future__ import annotations
Expand Down Expand Up @@ -39,15 +39,17 @@

from geedim.enums import ResamplingMethod

if '__file__' in globals():
if "__file__" in globals():
root_path = pathlib.Path(__file__).absolute().parents[1]
else:
root_path = pathlib.Path(os.getcwd())

_GDAL_AT_LEAST_35 = GDALVersion.runtime().at_least("3.5")


def Initialize(opt_url: Optional[str] = 'https://earthengine-highvolume.googleapis.com', **kwargs):
def Initialize(
opt_url: Optional[str] = "https://earthengine-highvolume.googleapis.com", **kwargs
):
"""
Initialise Earth Engine.
Expand All @@ -69,13 +71,17 @@ def Initialize(opt_url: Optional[str] = 'https://earthengine-highvolume.googleap

if not ee.data._credentials:
# Adpated from https://gis.stackexchange.com/questions/380664/how-to-de-authenticate-from-earth-engine-api.
env_key = 'EE_SERVICE_ACC_PRIVATE_KEY'
env_key = "EE_SERVICE_ACC_PRIVATE_KEY"

if env_key in os.environ:
# authenticate with service account
key_dict = json.loads(os.environ[env_key])
credentials = ee.ServiceAccountCredentials(key_dict['client_email'], key_data=key_dict['private_key'])
ee.Initialize(credentials, opt_url=opt_url, **kwargs)
credentials = ee.ServiceAccountCredentials(
key_dict["client_email"], key_data=key_dict["private_key"]
)
ee.Initialize(
credentials, opt_url=opt_url, project=key_dict["project_id"], **kwargs
)
else:
ee.Initialize(opt_url=opt_url, **kwargs)

Expand Down Expand Up @@ -121,11 +127,11 @@ def suppress_rio_logs(level: int = logging.ERROR):
try:
# GEE sets GeoTIFF `colorinterp` tags incorrectly. This suppresses `rasterio` warning relating to this:
# 'Sum of Photometric type-related color channels and ExtraSamples doesn't match SamplesPerPixel'
rio_level = logging.getLogger('rasterio').getEffectiveLevel()
logging.getLogger('rasterio').setLevel(level)
rio_level = logging.getLogger("rasterio").getEffectiveLevel()
logging.getLogger("rasterio").setLevel(level)
yield
finally:
logging.getLogger('rasterio').setLevel(rio_level)
logging.getLogger("rasterio").setLevel(rio_level)


def get_bounds(filename: pathlib.Path, expand: float = 5):
Expand All @@ -150,7 +156,10 @@ def get_bounds(filename: pathlib.Path, expand: float = 5):
expand_x = (bbox.right - bbox.left) * expand / 100.0
expand_y = (bbox.top - bbox.bottom) * expand / 100.0
bbox_expand = rio.coords.BoundingBox(
bbox.left - expand_x, bbox.bottom - expand_y, bbox.right + expand_x, bbox.top + expand_y
bbox.left - expand_x,
bbox.bottom - expand_y,
bbox.right + expand_x,
bbox.top + expand_y,
)
else:
bbox_expand = bbox
Expand All @@ -164,7 +173,9 @@ def get_bounds(filename: pathlib.Path, expand: float = 5):
]

bbox_expand_dict = dict(type="Polygon", coordinates=[coordinates])
src_bbox_wgs84 = warp.transform_geom(im.crs, "WGS84", bbox_expand_dict) # convert to WGS84 geojson
src_bbox_wgs84 = warp.transform_geom(
im.crs, "WGS84", bbox_expand_dict
) # convert to WGS84 geojson
return src_bbox_wgs84


Expand All @@ -187,19 +198,20 @@ def get_projection(image: ee.Image, min_scale: bool = True) -> ee.Projection:
Requested projection.
"""
if not isinstance(image, ee.Image):
raise TypeError('image is not an instance of ee.Image')
raise TypeError("image is not an instance of ee.Image")

bands = image.bandNames()
scales = bands.map(lambda band: image.select(ee.String(band)).projection().nominalScale())
scales = bands.map(
lambda band: image.select(ee.String(band)).projection().nominalScale()
)
projs = bands.map(lambda band: image.select(ee.String(band)).projection())
projs = projs.sort(scales)

return ee.Projection(projs.get(0) if min_scale else projs.get(-1))


class Spinner(Thread):

def __init__(self, label='', interval=0.2, leave=True, **kwargs):
def __init__(self, label="", interval=0.2, leave=True, **kwargs):
"""
Thread sub-class to run a non-blocking spinner.
Expand Down Expand Up @@ -234,20 +246,22 @@ def __exit__(self, exc_type, exc_val, exc_tb):

def run(self):
"""Run the spinner thread."""
cursors_it = itertools.cycle('/-\|')
cursors_it = itertools.cycle("/-\|")

while self._run:
cursor = next(cursors_it)
tqdm.write('\r' + self._label + cursor, file=self._file, end='')
tqdm.write("\r" + self._label + cursor, file=self._file, end="")
self._file.flush()
time.sleep(self._interval)

if self._leave == True:
tqdm.write('', file=self._file, end='\n')
tqdm.write("", file=self._file, end="\n")
elif self._leave == False:
tqdm.write('\r', file=self._file, end='')
tqdm.write("\r", file=self._file, end="")
elif isinstance(self._leave, str):
tqdm.write('\r' + self._label + self._leave + ' ', file=self._file, end='\n')
tqdm.write(
"\r" + self._label + self._leave + " ", file=self._file, end="\n"
)
self._file.flush()

def start(self):
Expand Down Expand Up @@ -289,7 +303,12 @@ def resample(ee_image: ee.Image, method: ResamplingMethod) -> ee.Image:

# resample the image, if it has a fixed projection
proj = get_projection(ee_image, min_scale=True)
has_fixed_proj = proj.crs().compareTo('EPSG:4326').neq(0).Or(proj.nominalScale().toInt64().neq(111319))
has_fixed_proj = (
proj.crs()
.compareTo("EPSG:4326")
.neq(0)
.Or(proj.nominalScale().toInt64().neq(111319))
)

def _resample(ee_image: ee.Image) -> ee.Image:
"""Resample the given image, allowing for additional 'average' method."""
Expand All @@ -313,15 +332,21 @@ def retry_session(
"""requests session configured for retries."""
session = session or requests.Session()
retry = Retry(
total=retries, read=retries, connect=retries, backoff_factor=backoff_factor, status_forcelist=status_forcelist
total=retries,
read=retries,
connect=retries,
backoff_factor=backoff_factor,
status_forcelist=status_forcelist,
)
adapter = HTTPAdapter(max_retries=retry)
session.mount('http://', adapter)
session.mount('https://', adapter)
session.mount("http://", adapter)
session.mount("https://", adapter)
return session


def expand_window_to_grid(win: Window, expand_pixels: Tuple[int, int] = (0, 0)) -> Window:
def expand_window_to_grid(
win: Window, expand_pixels: Tuple[int, int] = (0, 0)
) -> Window:
"""
Expand rasterio window extents to the nearest whole numbers i.e. for ``expand_pixels`` >= (0, 0), it will return a
window that contains the original extents.
Expand All @@ -348,7 +373,7 @@ def expand_window_to_grid(win: Window, expand_pixels: Tuple[int, int] = (0, 0))

def rio_crs(crs: str | rio.CRS) -> str | rio.CRS:
"""Convert a GEE CRS string to a rasterio compatible CRS string."""
if crs == 'SR-ORG:6974':
if crs == "SR-ORG:6974":
# This is a workaround for https://issuetracker.google.com/issues/194561313, that replaces the alleged GEE
# SR-ORG:6974 with actual WKT for SR-ORG:6842 taken from
# https://github.com/OSGeo/spatialreference.org/blob/master/scripts/sr-org.json.
Expand Down Expand Up @@ -378,8 +403,8 @@ def asset_id(image_id: str, folder: str = None):
"""
if not folder:
return image_id
im_name = image_id.replace('/', '-')
im_name = image_id.replace("/", "-")
folder = pathlib.PurePosixPath(folder)
cloud_folder = pathlib.PurePosixPath(folder.parts[0])
asset_path = pathlib.PurePosixPath('/'.join(folder.parts[1:])).joinpath(im_name)
return f'projects/{str(cloud_folder)}/assets/{str(asset_path)}'
asset_path = pathlib.PurePosixPath("/".join(folder.parts[1:])).joinpath(im_name)
return f"projects/{str(cloud_folder)}/assets/{str(asset_path)}"

0 comments on commit a5054f5

Please sign in to comment.