Skip to content

Commit

Permalink
MAINT: Allow returning status for W1 - and correct convergence check
Browse files Browse the repository at this point in the history
  • Loading branch information
jwboth committed Sep 28, 2024
1 parent 017b479 commit 8789ca6
Showing 1 changed file with 34 additions and 18 deletions.
52 changes: 34 additions & 18 deletions src/darsia/measure/wasserstein.py
Original file line number Diff line number Diff line change
Expand Up @@ -1387,6 +1387,7 @@ def __call__(

# Return solution
return_info = self.options.get("return_info", False)
return_status = self.options.get("return_status", False)
if return_info:
info.update(
{
Expand All @@ -1403,6 +1404,8 @@ def __call__(
}
)
return distance, info
elif return_status:
return distance, info["converged"]
else:
return distance

Expand Down Expand Up @@ -1693,7 +1696,7 @@ def _solve(self, flat_mass_diff: np.ndarray) -> tuple[float, np.ndarray, dict]:

# Define performance metric
info = {
"converged": iter < num_iter,
"converged": iter < num_iter - 1,
"number_iterations": iter,
"convergence_history": convergence_history,
"timings": total_timings,
Expand Down Expand Up @@ -1952,6 +1955,15 @@ def _solve(self, flat_mass_diff: np.ndarray) -> tuple[float, np.ndarray, dict]:
# Update distance
new_distance = self.l1_dissipation(flux)

# Catch nan values
if np.isnan(new_distance):
info = {
"converged": False,
"number_iterations": iter,
"convergence_history": convergence_history,
}
return new_distance, solution_i, info

# Determine the error in the mass conservation equation
mass_conservation_residual = (
self.div.dot(flux) - rhs[self.pressure_slice]
Expand Down Expand Up @@ -1993,22 +2005,26 @@ def _solve(self, flat_mass_diff: np.ndarray) -> tuple[float, np.ndarray, dict]:

# Print status
if self.verbose:
distance_increment = (
convergence_history["distance_increment"][-1] / new_distance
)
aux_force_increment = (
convergence_history["aux_force_increment"][-1]
/ convergence_history["aux_force_increment"][0]
)
mass_conservation_residual = convergence_history[
"mass_conservation_residual"
][-1]
print(
f"Iter. {iter} \t| {new_distance:.6e} \t| "
""
f"""{distance_increment:.6e} \t| {aux_force_increment:.6e} \t| """
f"""{mass_conservation_residual:.6e}"""
)
with warnings.catch_warnings():
warnings.filterwarnings(
"ignore", message="overflow encountered"
)
distance_increment = (
convergence_history["distance_increment"][-1] / new_distance
)
aux_force_increment = (
convergence_history["aux_force_increment"][-1]
/ convergence_history["aux_force_increment"][0]
)
mass_conservation_residual = convergence_history[
"mass_conservation_residual"
][-1]
print(
f"Iter. {iter} \t| {new_distance:.6e} \t| "
""
f"""{distance_increment:.6e} \t| {aux_force_increment:.6e} \t| """
f"""{mass_conservation_residual:.6e}"""
)

# Base stopping citeria on the different interpretations of the split Bregman
# method:
Expand Down Expand Up @@ -2058,7 +2074,7 @@ def _solve(self, flat_mass_diff: np.ndarray) -> tuple[float, np.ndarray, dict]:

# Define performance metric
info = {
"converged": iter < num_iter,
"converged": iter < num_iter - 1,
"number_iterations": iter,
"convergence_history": convergence_history,
"timings": total_timings,
Expand Down

0 comments on commit 8789ca6

Please sign in to comment.