Skip to content

Commit

Permalink
Feature/38 set logic options and sane defaults (#39)
Browse files Browse the repository at this point in the history
* new filter_set_options argument; existing tests working
* location and fault set operations implemented + tested
* added filter_set_options
* WIP all tests OK; using solvis-store OK;
* update upstream solvis libs; remove monkeypatching;
* fix schema bug;
* Bump version: 0.8.1 → 0.8.2
  • Loading branch information
chrisbc authored Jul 18, 2023
1 parent 55e69b9 commit 1b9fef1
Show file tree
Hide file tree
Showing 16 changed files with 768 additions and 440 deletions.
2 changes: 1 addition & 1 deletion .bumpversion.cfg
Original file line number Diff line number Diff line change
@@ -1,5 +1,5 @@
[bumpversion]
current_version = 0.8.1
current_version = 0.8.2
commit = True
tag = True

Expand Down
11 changes: 9 additions & 2 deletions CHANGELOG.md
Original file line number Diff line number Diff line change
@@ -1,12 +1,19 @@
# Changelog

## [0.8.3] - 2023-07-19
### Added
- new filter_set_options argument
- sane defaults for location_radius & fault_name set ops
- using solvis-store cache for fault_name filtering
### Changed
- updated upstream solvis libs
- removed monkeypatching for solvis/solvis-store

## [0.8.2] - 2023-07-04

### Changed
- added list support for corupture queries

## [0.8.1] - 2023-07-03

### Changed
- remove alpha from hexrgb color strings to improve geojson portability

Expand Down
694 changes: 366 additions & 328 deletions poetry.lock

Large diffs are not rendered by default.

8 changes: 5 additions & 3 deletions pyproject.toml
Original file line number Diff line number Diff line change
@@ -1,6 +1,6 @@
[tool.poetry]
name = "solvis-graphql-api"
version = "0.8.1"
version = "0.8.2"
description = "Graphql API for analysis of opensha modular Inversion Solutions"
authors = ["Chris Chamberlain <[email protected]>"]
license = "AGPL3"
Expand All @@ -25,10 +25,12 @@ Flask-GraphQL = "^2.0.1"
graphene = "<3"
pyyaml = "^6.0"

solvis-store = {git = "https://github.com/GNS-Science/solvis-store.git", rev = "d188deca8b7c1319053430cf8ff4e9adbd8cb0fa"}
nzshm-model = "^0.3.0"
nzshm-common = "^0.6.0"
solvis = "^0.7.0"
# solvis = "^0.7.0"
solvis = "^0.8.1"
solvis-store = {git = "https://github.com/GNS-Science/solvis-store", rev = "main"}

matplotlib = "^3.7.1"
werkzeug = "^2.3.3"

Expand Down
2 changes: 1 addition & 1 deletion solvis_graphql_api/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,4 +2,4 @@

__author__ = """GNS Science"""
__email__ = '[email protected]'
__version__ = '0.8.1'
__version__ = '0.8.2'
195 changes: 114 additions & 81 deletions solvis_graphql_api/composite_solution/cached.py
Original file line number Diff line number Diff line change
Expand Up @@ -5,36 +5,35 @@
import time
from functools import lru_cache
from pathlib import Path
from typing import Callable, Iterable, Iterator, List, Set, Tuple, Union
from typing import Any, Callable, Iterable, Iterator, List, Set, Tuple, Union

import geopandas as gpd
import nzshm_model
import solvis
from nzshm_common.location.location import location_by_id
from solvis.inversion_solution.typing import InversionSolutionProtocol
from solvis_store.config import DEPLOYMENT_STAGE
from solvis_store.solvis_db_query import get_rupture_ids
from solvis_store.query import get_fault_name_rupture_ids, get_location_radius_rupture_ids

from .filter_set_logic_options import SetOperationEnum

log = logging.getLogger(__name__)

FAULT_SECTION_LIMIT = 1e4

RESOLVE_LOCATIONS_INTERNALLY = False if DEPLOYMENT_STAGE == 'TEST' else True
# we want to use the solvis-store cache normally, override this in testing
RESOLVE_LOCATIONS_INTERNALLY = False # if DEPLOYMENT_STAGE == 'TEST' else True


@lru_cache
def get_location_polygon(radius_km, lon, lat):
return solvis.geometry.circle_polygon(radius_m=radius_km * 1000, lon=lon, lat=lat)


# TODO: this is here temporarily until we can get solvis published (GHA problems)
@lru_cache
def parent_fault_names(
sol: InversionSolutionProtocol, sort: Union[None, Callable[[Iterable], List]] = sorted
) -> List[str]:
if sort:
return sort(sol.fault_sections.ParentName.unique())
return list(sol.fault_sections.ParentName.unique())
return solvis.parent_fault_names(sol)


@lru_cache
Expand All @@ -58,21 +57,20 @@ def get_composite_solution(model_id: str) -> solvis.CompositeSolution:
return solvis.CompositeSolution.from_archive(Path(COMPOSITE_ARCHIVE_PATH), slt)


def filter_dataframe_by_radius(fault_system_solution, dataframe, location_ids, radius_km):
log.info('filter_dataframe_by_radius: %s %s %s' % (fault_system_solution, radius_km, location_ids))
rupture_ids = set()
for loc_id in location_ids:
loc = location_by_id(loc_id)
# print("LOC:", loc)
polygon = get_location_polygon(radius_km=radius_km, lon=loc['longitude'], lat=loc['latitude'])
rupture_ids = rupture_ids.union(set(fault_system_solution.get_ruptures_intersecting(polygon)))
# print(fault_system, len(rupture_ids))
# print('rupture_ids', rupture_ids)
return dataframe[dataframe["Rupture Index"].isin(rupture_ids)]
def get_rupture_ids_for_fault_names_stored(
model_id: str, fault_system: str, fault_names: Iterable[str], filter_set_options
) -> Iterator[int]:
log.info('get_rupture_ids_for_fault_names_stored: %s %s %s' % (model_id, fault_system, fault_names))
filter_set_options_dict = dict(filter_set_options)
fss = get_fault_system_solution_for_model(model_id, fault_system)
ruptset_ids = list(set([branch.rupture_set_id for branch in fss.branches]))
assert len(ruptset_ids) == 1
rupture_set_id = ruptset_ids[0]
union = False if filter_set_options_dict["multiple_faults"] == SetOperationEnum.INTERSECTION else True
return get_fault_name_rupture_ids(rupture_set_id, fault_names, union)


def filter_dataframe_by_radius_stored(model_id, fault_system, df0, location_ids, radius_km, union) -> Iterator[int]:
log.info('filter_dataframe_by_radius_stored: %s %s %s %s' % (model_id, fault_system, radius_km, location_ids))
def get_fault_system_solution_for_model(model_id, fault_system):
current_model = nzshm_model.get_model_version(model_id)
slt = current_model.source_logic_tree()

Expand All @@ -84,20 +82,94 @@ def get_fss(slt, fault_system):
# check the solutions in a given fault system have the same rupture_set
fss = get_fss(slt, fault_system)
assert fss is not None
return fss


def get_rupture_ids_for_location_radius(
fault_system_solution, location_ids, radius_km, filter_set_options: Tuple[Any]
) -> Set[int]:
log.info('get_rupture_ids_for_location_radius: %s %s %s' % (fault_system_solution, radius_km, location_ids))
filter_set_options_dict = dict(filter_set_options)
first = True
rupture_ids: Set[int]
for loc_id in location_ids:
loc = location_by_id(loc_id)
# print("LOC:", loc)
polygon = get_location_polygon(radius_km=radius_km, lon=loc['longitude'], lat=loc['latitude'])
location_rupture_ids = set(fault_system_solution.get_ruptures_intersecting(polygon))

if first:
rupture_ids = location_rupture_ids
first = False
else:
log.debug(
'filter_set_options_dict["multiple_locations"] %s' % filter_set_options_dict["multiple_locations"]
)
if filter_set_options_dict["multiple_locations"] == SetOperationEnum.INTERSECTION:
rupture_ids = rupture_ids.intersection(location_rupture_ids)
elif filter_set_options_dict["multiple_locations"] == SetOperationEnum.UNION:
rupture_ids = rupture_ids.union(location_rupture_ids)
else:
raise ValueError("unsupported SetOperation")
return rupture_ids


def get_rupture_ids_for_location_radius_stored(
model_id: str, fault_system: str, location_ids: Iterable[str], radius_km: int, filter_set_options: Tuple[Any]
) -> Iterator[int]:
log.info(
'get_rupture_ids_for_location_radius_stored: %s %s %s %s' % (model_id, fault_system, radius_km, location_ids)
)
filter_set_options_dict = dict(filter_set_options)

fss = get_fault_system_solution_for_model(model_id, fault_system)
ruptset_ids = list(set([branch.rupture_set_id for branch in fss.branches]))
assert len(ruptset_ids) == 1
rupture_set_id = ruptset_ids[0]

union = False if filter_set_options_dict["multiple_faults"] == SetOperationEnum.INTERSECTION else True
print("filter_dataframe_by_radius_stored", radius_km)
return get_rupture_ids(rupture_set_id=rupture_set_id, locations=location_ids, radius=radius_km, union=union)
print("get_rupture_ids_for_location_radius_stored", radius_km)
return get_location_radius_rupture_ids(
rupture_set_id=rupture_set_id, locations=location_ids, radius=radius_km, union=union
)


@lru_cache
def get_rupture_ids_for_parent_fault(fault_system_solution: InversionSolutionProtocol, fault_name: str) -> Set[int]:
return set(fault_system_solution.get_ruptures_for_parent_fault(fault_name))


def get_rupture_ids_for_fault_names(
fault_system_solution, corupture_fault_names, filter_set_options: Tuple[Any]
) -> Set[int]:
filter_set_options_dict = dict(filter_set_options)
fss = fault_system_solution
first = True
rupture_ids: Set[int]
for fault_name in corupture_fault_names:
if fault_name not in parent_fault_names(fss):
raise ValueError("Invalid fault name: %s" % fault_name)
tic22 = time.perf_counter()
fault_rupture_ids = get_rupture_ids_for_parent_fault(fss, fault_name)
tic23 = time.perf_counter()
log.debug('fss.get_ruptures_for_parent_fault %s: %2.3f seconds' % (fault_name, (tic23 - tic22)))

if first:
rupture_ids = fault_rupture_ids
first = False
else:
log.debug('filter_set_options_dict["multiple_faults"] %s' % filter_set_options_dict["multiple_faults"])
if filter_set_options_dict["multiple_faults"] == SetOperationEnum.INTERSECTION:
rupture_ids = rupture_ids.intersection(fault_rupture_ids)
elif filter_set_options_dict["multiple_faults"] == SetOperationEnum.UNION:
rupture_ids = rupture_ids.union(fault_rupture_ids)
else:
raise ValueError("AWHAAA")

return rupture_ids


@lru_cache
def matched_rupture_sections_gdf(
model_id: str,
Expand All @@ -108,6 +180,7 @@ def matched_rupture_sections_gdf(
max_rate: float,
min_mag: float,
max_mag: float,
filter_set_options: Tuple[Any],
union: bool = False,
corupture_fault_names: Union[None, Tuple[str]] = None,
) -> gpd.GeoDataFrame:
Expand All @@ -116,6 +189,8 @@ def matched_rupture_sections_gdf(
return a dataframe of the matched ruptures.
"""
log.debug('matched_rupture_sections_gdf() filter_set_options: %s' % filter_set_options)

tic0 = time.perf_counter()
composite_solution = get_composite_solution(model_id)

Expand All @@ -136,22 +211,14 @@ def matched_rupture_sections_gdf(

# co-rupture filter
if corupture_fault_names and len(corupture_fault_names):
first = True
rupture_ids: Set[int]
for fault_name in corupture_fault_names:
if fault_name not in parent_fault_names(fss):
raise ValueError("Invalid fault name: %s" % fault_name)
tic22 = time.perf_counter()
fault_rupture_ids = get_rupture_ids_for_parent_fault(fss, fault_name)
tic23 = time.perf_counter()
log.debug('fss.get_ruptures_for_parent_fault %s: %2.3f seconds' % (fault_name, (tic23 - tic22)))

if first:
rupture_ids = fault_rupture_ids
first = False
else:
rupture_ids = rupture_ids.intersection(fault_rupture_ids)

if RESOLVE_LOCATIONS_INTERNALLY:
rupture_ids = set(get_rupture_ids_for_fault_names(fss, corupture_fault_names, filter_set_options))
else:
rupture_ids = set(
get_rupture_ids_for_fault_names_stored(
model_id, fault_system, corupture_fault_names, filter_set_options
)
)
df0 = df0[df0["Rupture Index"].isin(list(rupture_ids))]

tic3 = time.perf_counter()
Expand All @@ -160,14 +227,17 @@ def matched_rupture_sections_gdf(
# location filters
if location_ids is not None and len(location_ids):
if RESOLVE_LOCATIONS_INTERNALLY:
df0 = filter_dataframe_by_radius(fss, df0, location_ids, radius_km)
rupture_ids = set(get_rupture_ids_for_location_radius(fss, location_ids, radius_km, filter_set_options))
else:
rupture_ids = set(filter_dataframe_by_radius_stored(model_id, fault_system, df0, location_ids, radius_km, union))
df0 = df0[df0["Rupture Index"].isin(rupture_ids)]
rupture_ids = set(
get_rupture_ids_for_location_radius_stored(
model_id, fault_system, location_ids, radius_km, filter_set_options
)
)
df0 = df0[df0["Rupture Index"].isin(rupture_ids)]

tic4 = time.perf_counter()
log.debug('matched_rupture_sections_gdf(): time apply location filters: %2.3f seconds' % (tic4 - tic3))

return df0


Expand All @@ -181,6 +251,7 @@ def fault_section_aggregates_gdf(
max_rate: float,
min_mag: float,
max_mag: float,
filter_set_options: Tuple[Any],
union: bool = False,
trace_only: bool = False,
corupture_fault_names: Union[None, Tuple[str]] = None,
Expand All @@ -202,6 +273,7 @@ def fault_section_aggregates_gdf(
max_rate,
min_mag,
max_mag,
filter_set_options,
union,
corupture_fault_names,
)
Expand Down Expand Up @@ -241,42 +313,3 @@ def fault_section_aggregates_gdf(
raise ValueError("No fault sections satisfy the filter.")

return rupture_sections_gdf


# class ColourScaleNormalise(graphene.Enum):
# LOG = "log"
# LIN = "lin"


# COLOR_SCALE_NORMALISE_LOG = 'log' if os.getenv('COLOR_SCALE_NORMALISATION', '').upper() == 'LOG' else 'lin'


# class HexRgbValueMapping(graphene.ObjectType):
# levels = graphene.List(graphene.Float)
# hexrgbs = graphene.List(graphene.String)


# @lru_cache
# def get_normaliser(color_scale_vmax, color_scale_vmin, color_scale_normalise):
# if color_scale_normalise == ColourScaleNormalise.LOG:
# log.debug("resolve_hazard_map using LOG normalized colour scale")
# norm = mpl.colors.LogNorm(vmin=color_scale_vmin, vmax=color_scale_vmax)
# else:
# color_scale_vmin = color_scale_vmin or 0
# log.debug("resolve_hazard_map using LIN normalized colour scale")
# norm = mpl.colors.Normalize(vmin=color_scale_vmin, vmax=color_scale_vmax)
# return norm


# @lru_cache
# def get_colour_scale(color_scale: str, color_scale_normalise, vmax: float, vmin: float) -> HexRgbValueMapping:
# # build the colour_scale
# assert vmax * 2 == int(vmax * 2) # make sure we have a value on a 0.5 interval
# levels, hexrgbs = [], []
# cmap = mpl.colormaps[color_scale]
# norm = get_normaliser(vmax, vmin, color_scale_normalise)
# for level in range(0, int(vmax * 10) + 1):
# levels.append(level / 10)
# hexrgbs.append(mpl.colors.to_hex(cmap(norm(level / 10))))
# hexrgb = HexRgbValueMapping(levels=levels, hexrgbs=hexrgbs)
# return hexrgb
Loading

0 comments on commit 1b9fef1

Please sign in to comment.