Skip to content

Commit

Permalink
make ahmanalysis not use libres facade
Browse files Browse the repository at this point in the history
  • Loading branch information
yngve-sk committed Feb 5, 2025
1 parent 13710e6 commit 48d9d44
Show file tree
Hide file tree
Showing 4 changed files with 85 additions and 36 deletions.
92 changes: 71 additions & 21 deletions src/semeio/workflows/ahm_analysis/ahmanalysis.py
Original file line number Diff line number Diff line change
Expand Up @@ -8,7 +8,11 @@
import ert
import numpy as np
import pandas as pd
from ert.analysis import ErtAnalysisError, SmootherSnapshot
import polars
import polars as pl
from ert import LibresFacade
from ert.analysis import ErtAnalysisError, SmootherSnapshot, smoother_update
from ert.config import ESSettings, Field, GenKwConfig, UpdateSettings
from ert.storage import open_storage
from scipy.stats import ks_2samp

Expand Down Expand Up @@ -105,20 +109,44 @@ class AhmAnalysisJob(SemeioScript):

def run(
self,
alpha: str | None = None,
target_name="analysis_case",
prior_name=None,
group_by="data_key",
output_dir=None,
):
# (SemeioScript wraps this run method)

"""Perform analysis of parameters change per obs group
prior to posterior of ahm"""

if isinstance(alpha, str):
alpha = float(alpha)

if output_dir is not None:
self._reports_dir = output_dir

obs_keys = list(self.facade.get_observations().obs_vectors.keys())
key_map = _group_observations(self.facade, obs_keys, group_by)
experiment = self.ensemble.experiment

observations_and_responses_mapping = (
pl.concat(
df["observation_key", "response_key"]
for df in experiment.observations.values()
)
if len(experiment.observations) > 0
else polars.DataFrame({"observation_key": [], "response_key": []})
)

response_2_obs_df = observations_and_responses_mapping.group_by(
"response_key"
).agg(pl.col("observation_key").unique())

def _replace(s: str) -> str:
return s.replace(":", "_")

# Note: Original behavior replaces : with _, omitting this
key_map = {
_replace(l["response_key"]): sorted(map(_replace, l["observation_key"]))
for l in response_2_obs_df.sort(by="response_key").to_dicts()
}

prior_name, target_name = check_names(
self.ensemble.name,
Expand All @@ -142,7 +170,7 @@ def run(
raise_if_empty(
dataframes=[
prior_data,
self.facade.load_all_misfit_data(prior_ensemble),
LibresFacade.load_all_misfit_data(prior_ensemble),
],
messages=[
"Empty prior ensemble",
Expand All @@ -155,13 +183,22 @@ def run(
# create dataframe with observations vectors (1 by 1 obs and also all_obs)
combinations = make_obs_groups(key_map)

field_parameters = sorted(self.facade.get_field_parameters())
field_parameters = [
p.name
for p in experiment.parameter_configuration.values()
if isinstance(p, Field)
]
gen_kws = [
p.name
for p in experiment.parameter_configuration.values()
if isinstance(p, GenKwConfig)
]
if field_parameters:
logger.warning(
f"AHM_ANALYSIS will only evaluate scalar parameters, skipping: {field_parameters}"
)

scalar_parameters = sorted(self.facade.get_gen_kw())
scalar_parameters = sorted(gen_kws)
# identify the set of actual parameters that was updated for now just go
# through scalar parameters but in future if easier access to field parameter
# updates should also include field parameters
Expand All @@ -183,26 +220,26 @@ def run(
# storage in a temporary directory.
with (
tempfile.TemporaryDirectory(),
open_storage("tmp_storage", "w") as storage,
open_storage("tmp_storage", "w") as tmp_storage,
):
try:
prev_experiment = prior_ensemble.experiment
experiment = storage.create_experiment(
experiment = tmp_storage.create_experiment(
parameters=prev_experiment.parameter_configuration.values(),
observations=prev_experiment.observations,
responses=prev_experiment.response_configuration.values(),
)
target_ensemble = storage.create_ensemble(
target_ensemble = tmp_storage.create_ensemble(
experiment,
name=target_name,
ensemble_size=prior_ensemble.ensemble_size,
)
update_log = _run_ministep(
self.facade,
prior_ensemble,
target_ensemble,
obs_group,
field_parameters + scalar_parameters,
alpha,
)
# Get the active vs total observation info
df_update_log = make_update_log_df(update_log)
Expand All @@ -225,7 +262,7 @@ def run(
calc_observationsgroup_misfit(
group_name,
df_update_log,
self.facade.load_all_misfit_data(prior_ensemble),
LibresFacade.load_all_misfit_data(prior_ensemble),
)
]
# Calculate Ks matrix for scalar parameters
Expand All @@ -247,14 +284,27 @@ def run(
self.reporter.publish_csv("prior", prior_data)


def _run_ministep(facade, prior_storage, target_storage, obs_group, data_parameters):
rng = np.random.default_rng(seed=facade.config.random_seed)
return facade.smoother_update(
prior_storage,
target_storage,
target_storage.name,
obs_group,
data_parameters,
def _run_ministep(
prior_storage,
target_storage,
obs_group,
data_parameters,
alpha: float | None = None,
):
es_settings = ESSettings()
obs_settings = (
UpdateSettings(alpha=alpha) if alpha is not None else UpdateSettings()
)

rng = np.random.default_rng()

return smoother_update(
prior_storage=prior_storage,
posterior_storage=target_storage,
observations=obs_group,
parameters=data_parameters,
update_settings=obs_settings,
es_settings=es_settings,
rng=rng,
)

Expand Down
Original file line number Diff line number Diff line change
@@ -1,11 +1,11 @@
,Parameters,FOPR,All_obs,All_obs-SNAKE_OIL_WPR_DIFF,All_obs-WOPR_OP1
0,SNAKE_OIL_PARAM:BPR_138_PERSISTENCE,0.3,0.3,0.3,0.3
1,SNAKE_OIL_PARAM:BPR_555_PERSISTENCE,0.2,0.2,0.2,0.2
2,SNAKE_OIL_PARAM:OP1_DIVERGENCE_SCALE,0.4,0.4,0.4,0.4
3,SNAKE_OIL_PARAM:OP1_OCTAVES,0.2,0.2,0.2,0.2
4,SNAKE_OIL_PARAM:OP1_OFFSET,0.3,0.3,0.3,0.3
5,SNAKE_OIL_PARAM:OP1_PERSISTENCE,0.2,0.2,0.2,0.2
6,SNAKE_OIL_PARAM:OP2_DIVERGENCE_SCALE,0.2,0.2,0.2,0.2
7,SNAKE_OIL_PARAM:OP2_OCTAVES,0.2,0.2,0.2,0.2
8,SNAKE_OIL_PARAM:OP2_OFFSET,0.5,0.5,0.5,0.5
9,SNAKE_OIL_PARAM:OP2_PERSISTENCE,0.2,0.2,0.2,0.2
,Parameters,FOPR,All_obs,All_obs-WOPR_OP1,All_obs-SNAKE_OIL_WPR_DIFF
0,SNAKE_OIL_PARAM:BPR_138_PERSISTENCE,0.3,0.3,0.4,0.3
1,SNAKE_OIL_PARAM:BPR_555_PERSISTENCE,0.2,0.3,0.2,0.3
2,SNAKE_OIL_PARAM:OP1_DIVERGENCE_SCALE,0.4,0.3,0.3,0.4
3,SNAKE_OIL_PARAM:OP1_OCTAVES,0.3,0.2,0.2,0.3
4,SNAKE_OIL_PARAM:OP1_OFFSET,0.2,0.3,0.2,0.2
5,SNAKE_OIL_PARAM:OP1_PERSISTENCE,0.2,0.2,0.1,0.2
6,SNAKE_OIL_PARAM:OP2_DIVERGENCE_SCALE,0.3,0.2,0.3,0.3
7,SNAKE_OIL_PARAM:OP2_OCTAVES,0.2,0.2,0.2,0.3
8,SNAKE_OIL_PARAM:OP2_OFFSET,0.4,0.5,0.4,0.4
9,SNAKE_OIL_PARAM:OP2_PERSISTENCE,0.3,0.3,0.2,0.1
1 change: 0 additions & 1 deletion tests/workflows/ahm_analysis/test_ahm_analysis.py
Original file line number Diff line number Diff line change
Expand Up @@ -28,7 +28,6 @@ def test_make_update_log_df(snake_oil_facade, snapshot):
prior_ensemble=prior_ens,
)
log = _run_ministep(
snake_oil_facade,
prior_ens,
posterior_ens,
sorted(prior_ens.experiment.observation_keys),
Expand Down
6 changes: 3 additions & 3 deletions tests/workflows/ahm_analysis/test_integration.py
Original file line number Diff line number Diff line change
Expand Up @@ -51,7 +51,7 @@ def test_ahmanalysis_run(snake_oil_facade):
ks_df = pd.read_csv(output_dir / "ks.csv")
for keys in ks_df["Parameters"].tolist():
assert keys in parameters
assert ks_df.columns[1:].tolist() == group_obs
assert set(ks_df.columns[1:].tolist()) == set(group_obs)
assert ks_df["WOPR_OP1"].max() <= 1
assert ks_df["WOPR_OP1"].min() >= 0
assert (output_dir / "active_obs_info.csv").is_file()
Expand All @@ -68,8 +68,7 @@ def test_ahmanalysis_run_deactivated_obs(snake_oil_facade, snapshot, caplog):
We simulate a case where some of the observation groups are completely
disabled by outlier detection
"""

snake_oil_facade.config.analysis_config.observation_settings.alpha = 0.1
# Note: Unattainable run unless we can pass alpha to the workflow somehow
with (
open_storage(snake_oil_facade.enspath, "w") as storage,
caplog.at_level(logging.WARNING),
Expand All @@ -79,6 +78,7 @@ def test_ahmanalysis_run_deactivated_obs(snake_oil_facade, snapshot, caplog):
ahmanalysis.AhmAnalysisJob,
storage,
experiment.get_ensemble_by_name("default"),
0.1,
)
assert "Analysis failed for" in caplog.text

Expand Down

0 comments on commit 48d9d44

Please sign in to comment.