No error to check convergence in output of SpatioTemporalProblem #768

Open
matthieuheitz opened this issue Dec 5, 2024 · 4 comments
Labels: documentation (Improvements or additions to documentation), examples


@matthieuheitz

In the output of TemporalProblem, I can look at the marginal error over the iterations (I'm using balanced OT) to get a more detailed sense of the algorithm's convergence than the boolean "converged" gives me, using:
tp.solutions[(0,1)]._errors

However, with SpatioTemporalProblem, this _errors field doesn't exist.
Is that expected for the FGW problem? Can't the marginal error be computed for it?
Otherwise, could it return the error vector of the last inner Sinkhorn loop?
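For reference, a balanced-OT error trace like the one in _errors measures how far the current coupling's marginals are from the prescribed ones. Below is a minimal pure-Python Sinkhorn sketch that records the L1 row-marginal error every record_every iterations. The function name and parameters are illustrative assumptions; ott-jax's exact definition (which norm, which marginal is checked, how often errors are recorded) may differ.

```python
import math


def sinkhorn_with_errors(cost, a, b, epsilon=0.1, max_iterations=200, record_every=10):
    """Balanced entropic OT via Sinkhorn; returns the coupling and an error trace.

    Illustrative sketch only, not ott-jax's implementation. The error is the
    L1 distance between the coupling's row sums and `a`; after each column
    update the column sums match `b`, so only the row marginal can be off.
    """
    n, m = len(a), len(b)
    K = [[math.exp(-cost[i][j] / epsilon) for j in range(m)] for i in range(n)]
    u = [1.0] * n
    v = [1.0] * m
    errors = []
    for it in range(1, max_iterations + 1):
        # Rescale rows to match `a`, then columns to match `b`.
        u = [a[i] / sum(K[i][j] * v[j] for j in range(m)) for i in range(n)]
        v = [b[j] / sum(K[i][j] * u[i] for i in range(n)) for j in range(m)]
        if it % record_every == 0:
            row_sums = [sum(u[i] * K[i][j] * v[j] for j in range(m)) for i in range(n)]
            errors.append(sum(abs(r - ai) for r, ai in zip(row_sums, a)))
    coupling = [[u[i] * K[i][j] * v[j] for j in range(m)] for i in range(n)]
    return coupling, errors


cost = [[0.0, 1.0], [1.0, 0.0]]
coupling, errors = sinkhorn_with_errors(cost, [0.5, 0.5], [0.3, 0.7], epsilon=0.5)
print(errors[:3], errors[-1])  # the error shrinks toward 0 as the marginals are matched
```

A trace that flattens out well above zero would suggest the solver stopped before matching the marginals, which is the kind of detail the boolean "converged" hides.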

Thanks!

@MUCDK
Collaborator

MUCDK commented Dec 6, 2024

Hi @matthieuheitz ,

Thanks for raising this.

What do you mean by "it does not exist"? I guess it should exist but might be None. At least we set it here:

self._errors = output.errors

But I do see that store_inner_errors might be False by default, as it is in ott-jax: https://github.com/ott-jax/ott/blob/690b1aed1c0519899c94dcf0ccdd84500127af61/src/ott/solvers/was_solver.py#L40, which causes the inner errors not to be saved here: https://github.com/ott-jax/ott/blob/626aad6efed729a9e167b0963f4c447a2697e119/src/ott/solvers/quadratic/gromov_wasserstein.py#L291

It thus should be possible to set it via kwargs={"store_inner_errors": True}.
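Whether the kwargs={...} form works depends on how solve() forwards keyword arguments. A minimal sketch, assuming (hypothetically) that solve() collects extra keywords with **kwargs and passes them straight to the solver; solve and make_solver here are illustrative stand-ins, not moscot's actual code.

```python
def make_solver(store_inner_errors=False, **unused):
    """Hypothetical solver factory; keywords it does not know land in `unused`."""
    return {"store_inner_errors": store_inner_errors, "unused": unused}


def solve(**kwargs):
    """Hypothetical solve() that forwards every extra keyword to the solver."""
    return make_solver(**kwargs)


# A dict passed under the name `kwargs` becomes a keyword literally called "kwargs",
# so the solver never sees store_inner_errors.
wrapped = solve(kwargs={"store_inner_errors": True})
# Passing the flag directly reaches the solver's parameter.
direct = solve(store_inner_errors=True)

print(wrapped)  # {'store_inner_errors': False, 'unused': {'kwargs': {'store_inner_errors': True}}}
print(direct)   # {'store_inner_errors': True, 'unused': {}}
```

Under this assumption, the kwargs= form only works if solve() explicitly unpacks a parameter named kwargs.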

Pinging @selmanozleyen, who is currently working on updating moscot to ott-jax 0.5.0. @selmanozleyen, can you please verify this?

Also, @ArinaDanilina, once we have resolved this (and updated to a new version), could we please write an example of how to investigate convergence? I.e., an example that plots the costs and errors and explains the difference between them?

MUCDK added the examples and documentation labels Dec 6, 2024
@matthieuheitz
Author

Thanks for your answer!

Yes, you are correct, the attribute exists, but it's equal to None.

Actually, it didn't work when I passed kwargs={"store_inner_errors": True} to solve(), but it worked when I directly passed store_inner_errors=True to solve().
The _errors field then contains a (50,200) array.
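The (50, 200) shape is consistent with one row per outer iteration and one column per recorded inner Sinkhorn error; if, as we believe, ott-jax records a Sinkhorn error every 10 iterations by default, 2000 inner iterations give 2000 / 10 = 200 columns. A small helper to pull out the last recorded error of each outer iteration, assuming (as we understand ott-jax does) that unwritten entries are left at a negative placeholder such as -1; last_inner_errors is a hypothetical name.

```python
def last_inner_errors(errors):
    """Return the last recorded inner error of each outer iteration.

    Assumes `errors` is a 2-D nested list with one row per outer iteration and
    one column per recorded inner (Sinkhorn) error, where entries that were never
    written hold a negative placeholder (e.g. -1). Empty rows yield None.
    """
    result = []
    for row in errors:
        recorded = [e for e in row if e >= 0]
        result.append(recorded[-1] if recorded else None)
    return result


# Toy 3 x 4 trace: the inner solver stopped early in rows 0 and 1; row 2 was never filled.
toy = [
    [0.5, 0.1, 0.01, -1.0],
    [0.2, 0.005, -1.0, -1.0],
    [-1.0, -1.0, -1.0, -1.0],
]
print(last_inner_errors(toy))  # [0.01, 0.005, None]
```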

Of note, there seems to be an inconsistency between the number of (outer) iterations you set through max_iterations and the actual number of iterations (tracked with progress_fn): there is always one more actual iteration than max_iterations says.
For example, with max_iterations=1, my callback is called for all 2000 inner Sinkhorn iterations, and then again for another 2000 inner Sinkhorn iterations.
The number of error vectors in _errors matches max_iterations, so it is also inconsistent with the actual number of iterations.
It seems that the error of the first outer iteration is the one missing from _errors.
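One way this pattern could arise, as a hypothetical model rather than ott-jax's actual implementation: the solver runs one inner Sinkhorn solve while initializing and one per outer iteration, but only writes error rows inside the loop. run_outer_loop is an illustrative name.

```python
def run_outer_loop(max_iterations):
    """Hypothetical outer loop that performs one inner solve before iterating.

    Returns the total number of inner solves and the error rows that were stored.
    Only solves inside the loop contribute a row to the error trace.
    """
    inner_solves = 0
    stored_error_rows = []

    def inner_solve(label):
        nonlocal inner_solves
        inner_solves += 1
        return f"errors of {label}"

    inner_solve("initialization")  # counted by a progress callback, but not stored
    for it in range(max_iterations):
        stored_error_rows.append(inner_solve(f"iteration {it + 1}"))
    return inner_solves, stored_error_rows


solves, rows = run_outer_loop(max_iterations=1)
print(solves, len(rows))  # 2 1
```

With max_iterations=1 this model gives two inner solves but a single stored row, and the missing row is the first one.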

@matthieuheitz
Author

matthieuheitz commented Dec 6, 2024

Oh, it might be because the first iteration is iteration 0, so max_iterations is not the maximum number of iterations but the maximum iteration index.
But in that case, since e.g. the plot_costs() function plots iteration 0, it would be consistent for _errors to also include the error for iteration 0.
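Until costs and errors cover the same iterations, one workaround for plotting them side by side is to pad the missing leading error entries. A minimal sketch, assuming costs has one entry per outer iteration starting at iteration 0 and errors is missing only its leading entries; align_costs_and_errors is a hypothetical helper.

```python
def align_costs_and_errors(costs, errors):
    """Pair per-iteration costs with per-iteration errors for plotting.

    Assumes `costs` starts at iteration 0 and `errors` lacks its leading
    entries; those gaps are filled with None so indices line up.
    """
    missing = len(costs) - len(errors)
    if missing < 0:
        raise ValueError("errors has more entries than costs")
    padded = [None] * missing + list(errors)
    return list(enumerate(zip(costs, padded)))


pairs = align_costs_and_errors([3.2, 2.9, 2.8], [0.004, 0.001])
for it, (cost, err) in pairs:
    print(it, cost, err)
# 0 3.2 None
# 1 2.9 0.004
# 2 2.8 0.001
```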

@MUCDK
Copy link
Collaborator

MUCDK commented Dec 9, 2024

It's true that counting starts at 0, but I'm not sure why this is not being plotted then, following:

def _plot_lines(

@ArinaDanilina could you please check that?
