Skip to content

Commit

Permalink
CensorAndScale for continuous
Browse files Browse the repository at this point in the history
  • Loading branch information
GondekNP committed Jan 11, 2024
1 parent 29eabac commit ef66df2
Show file tree
Hide file tree
Showing 4 changed files with 60 additions and 37 deletions.
6 changes: 3 additions & 3 deletions app.py
Original file line number Diff line number Diff line change
Expand Up @@ -159,13 +159,13 @@ def analyze_burn(body: AnaylzeBurnPOSTBody, sftp_client: SFTPClient = Depends(ge
logger.log_text(f"Error: {e}")
return f"Error: {e}", 400

@app.get("/map/{fire_event_name}/{burn_metric}", response_class=HTMLResponse)
def serve_map(request: Request, fire_event_name: str, burn_metric: str, manifest: dict = Depends(get_manifest)):
@app.get("/map/{affiliation}/{fire_event_name}/{burn_metric}", response_class=HTMLResponse)
def serve_map(request: Request, fire_event_name: str, burn_metric: str, affiliation: str, manifest: dict = Depends(get_manifest)):
mapbox_token = get_mapbox_secret()

tileserver_endpoint = 'https://tf-rest-burn-severity-ohi6r6qs2a-uc.a.run.app'
# tileserver_endpoint = 'http://localhost:5050'
cog_url = f"https://burn-severity-backend.s3.us-east-2.amazonaws.com/public/{fire_event_name}/{burn_metric}.tif"
cog_url = f"https://burn-severity-backend.s3.us-east-2.amazonaws.com/public/{affiliation}/{fire_event_name}/{burn_metric}.tif"
cog_tileserver_url_prefix = tileserver_endpoint + f"/cog/tiles/WebMercatorQuad/{{z}}/{{x}}/{{y}}.png?url={cog_url}&nodata=-99&return_mask=true"

fire_metadata = manifest[fire_event_name]
Expand Down
22 changes: 0 additions & 22 deletions src/lib/burn_severity.py
Original file line number Diff line number Diff line change
Expand Up @@ -104,25 +104,3 @@ def classify_burn(array, thresholds):
reclass = xr.where((array < threshold) & (reclass.isnull()), value, reclass)

return reclass

def is_s3_url_valid(url):
    """Return True if *url* names an S3 object that exists and is publicly listable.

    The request is made unsigned so the check works against public buckets
    even when no AWS credentials are configured. Returns False on any
    failure (missing object, bad URL, or client error).
    """
    client = boto3.client('s3')
    # Disable request signing so public buckets can be probed anonymously.
    client.meta.events.register('choose-signer.s3.*', disable_signing)

    # URL layout assumed: <scheme>://<bucket-host>/<key...> — TODO confirm
    # this matches the virtual-hosted style used elsewhere in the project.
    parts = url.split('/')
    bucket_name = parts[2]
    key = '/'.join(parts[3:])

    try:
        listing = client.list_objects_v2(Bucket=bucket_name, Prefix=key)
        # Prefix matching is not exact matching, so compare keys explicitly.
        return any(entry['Key'] == key for entry in listing.get('Contents', []))
    except NoCredentialsError:
        print("No AWS credentials found")
        return False
    except Exception as e:
        print(f"Invalid S3 URL: {url}. Exception: {str(e)}")
        return False
56 changes: 45 additions & 11 deletions src/lib/titiler_algorithms.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,6 +3,22 @@
from rio_tiler.models import ImageData
import numpy as np

def convert_to_rgb(classified: np.ndarray, mask: np.ndarray) -> np.ndarray:
    """Promote a grayscale band to a red-tinted RGB masked array.

    The red channel is fully saturated (255) while green and blue carry the
    grayscale values, producing a red colour ramp. The input mask is
    replicated across all three channels.
    """
    saturated_red = np.full_like(classified, 255)
    rgb = np.stack((saturated_red, classified, classified), axis=0)
    # Repeat the mask per channel; squeeze drops a singleton band axis if present.
    channel_mask = np.stack((mask, mask, mask), axis=0).squeeze()
    return np.ma.MaskedArray(rgb, mask=channel_mask)


class Classify(BaseAlgorithm):

# Parameters
Expand All @@ -18,17 +34,7 @@ def __call__(self, img: ImageData) -> ImageData:

classified = np.select(threshold_checks, png_int_values).astype(np.uint8)

# Convert to a red rgb image, from grayscale
r_channel = np.full_like(classified, 255)
classified_rgb = np.stack(
[r_channel, classified, classified],
axis=0
)
rgb_mask = np.stack(
[mask, mask, mask],
axis=0
).squeeze()
final_img = np.ma.MaskedArray(classified_rgb, mask=rgb_mask)
final_img = convert_to_rgb(classified, mask)

# Create output ImageData
return ImageData(
Expand All @@ -38,7 +44,35 @@ def __call__(self, img: ImageData) -> ImageData:
bounds=img.bounds,
)

class CensorAndScale(BaseAlgorithm):
    """Censor values outside [min, max] and linearly rescale the rest to 0-255.

    Pixels below ``min`` map to 0, pixels above ``max`` map to 255, and
    values in between scale linearly. NaN pixels and the -99 nodata
    sentinel are masked out (rendered transparent). The result is a
    red-ramp RGB image via ``convert_to_rgb``.
    """

    # Parameters — these intentionally shadow the builtins because
    # BaseAlgorithm exposes them as user-facing query parameters.
    min: float
    max: float

    def __call__(self, img: ImageData) -> ImageData:
        # Work on a float copy so the caller's array is never mutated in
        # place and integer inputs are not truncated during the division.
        data = img.data.astype(np.float64)

        # Mask NaNs and the -99 nodata sentinel before any values are remapped.
        mask_transparent = np.isnan(data) | (data == -99)

        span = self.max - self.min
        if span == 0:
            # Degenerate range: avoid division by zero; everything at or
            # below min maps to 0, anything above maps to 255.
            scaled = np.where(data > self.max, 255.0, 0.0)
        else:
            # clip() censors out-of-range values to the 0/255 endpoints,
            # matching the explicit below-min/above-max assignments.
            scaled = np.clip((data - self.min) / span, 0.0, 1.0) * 255

        # Cast to uint8 for PNG output, consistent with the Classify algorithm.
        final_img = convert_to_rgb(scaled.astype(np.uint8), mask_transparent)

        return ImageData(
            final_img,
            assets=img.assets,
            crs=img.crs,
            bounds=img.bounds,
        )

algorithms = default_algorithms.register(
{
Expand Down
13 changes: 12 additions & 1 deletion src/util/sftp.py
Original file line number Diff line number Diff line change
Expand Up @@ -8,7 +8,7 @@
import datetime
import rasterio
from rasterio.enums import Resampling

import geopandas as gpd
from google.cloud import logging as cloud_logging

class SFTPClient:
Expand Down Expand Up @@ -148,6 +148,17 @@ def upload_cogs(self, metrics_stack, fire_event_name, prefire_date_range, postfi
source_local_path=local_cog_path,
remote_path=f"{affiliation}/{fire_event_name}/{band_name}.tif",
)

# Upload the difference between dNBR and RBR
local_cog_path = os.path.join(tmpdir, f"pct_change_dnbr_rbr.tif")
pct_change = (metrics_stack.sel(burn_metric="rbr") - metrics_stack.sel(burn_metric="dnbr")) / \
metrics_stack.sel(burn_metric="dnbr") * 100
pct_change.rio.to_raster(local_cog_path, driver="GTiff")
self.upload(
source_local_path=local_cog_path,
remote_path=f"{affiliation}/{fire_event_name}/pct_change_dnbr_rbr.tif",
)


def update_manifest(self, fire_event_name, bounds, prefire_date_range, postfire_date_range, affiliation):
with tempfile.TemporaryDirectory() as tmpdir:
Expand Down

0 comments on commit ef66df2

Please sign in to comment.