
[Question] odd behaviour when solving SDE for differential initial conditions #560

Open
aurelio-amerio opened this issue Jan 5, 2025 · 2 comments
Labels
question User queries

Comments

@aurelio-amerio

Hello,
I am trying to use diffrax as an SDE solver to reverse a diffusion process, in the context of diffusion models for machine learning.
In particular, given the SDE of the forward diffusion process, I want to solve the reverse SDE in order to sample new data.
Every implementation I've seen seems to reinvent the wheel by reimplementing an SDE solver, which I would like to avoid.

Coming to my problem: when I implement the Euler algorithm manually, I obtain reasonable results (I obtain samples close to every point of the training dataset with equal probability). On the other hand, when using diffrax to solve the reverse SDE, not every solution appears equiprobable, which is very odd, and problematic for this application.
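For reference, the "manual" approach I mean is a plain Euler-Maruyama loop of roughly this form. This is a minimal NumPy sketch with made-up drift, dispersion, and score functions (an Ornstein-Uhlenbeck-like toy, not the notebook's actual code); the key detail is that the noise increment has the same shape as the state:

```python
import numpy as np

def reverse_sde_euler(score, x0, drift, dispersion, t1=1.0, n_steps=1000, rng=None):
    """Euler-Maruyama for a reverse-time SDE, integrating from t = t1 down to 0.

    Reverse-time dynamics: dx = [f(x, t) - g(t)^2 * score(x, t)] dt + g(t) dW.
    """
    rng = np.random.default_rng() if rng is None else rng
    dt = t1 / n_steps
    x = np.asarray(x0, dtype=float).copy()
    for i in range(n_steps):
        t = t1 - i * dt
        g = dispersion(t)
        # Crucially, the noise has the same shape as x: one independent
        # Gaussian increment per dimension, scaled by sqrt(dt).
        dw = rng.normal(size=x.shape) * np.sqrt(dt)
        # Stepping backwards in time, so the drift is multiplied by -dt.
        x = x + (drift(x, t) - g**2 * score(x, t)) * (-dt) + g * dw
    return x

# Toy example (hypothetical choices, just to exercise the loop):
sample = reverse_sde_euler(
    score=lambda x, t: -x,          # score of a standard Gaussian
    x0=np.random.default_rng(0).normal(size=(2,)),
    drift=lambda x, t: -0.5 * x,
    dispersion=lambda t: 1.0,
)
print(sample.shape)  # (2,)
```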

I've attached a Colab notebook with a minimal example of the problem, for a very simple dataset:

https://colab.research.google.com/drive/1V1nU3vn9hkZvWcJnaYYWOwnsOvBB9iSd?usp=sharing

In this case, I don't implement the NN that would approximate the score; since the dataset is very simple, I compute the score directly, to test whether the SDE solver behaves as expected.

If you could help me figure out this problem I would really appreciate it. I am not too familiar with either diffusion models or this library, so I don't have a good intuition, but could it be a problem with the RNG somehow?

Thank you very much for your help and time!

@patrick-kidger
Owner

patrick-kidger commented Jan 5, 2025

So the key thing to spot is that the denser points are aligned with the line x = y. This suggests that you are getting correlated noise between your two evolving dimensions. And indeed this:

brownian_motion = VirtualBrownianTree(t0, t1, tol=1e-5, shape=(), key=keys[1])

is requesting only a scalar shape -- not a shape of (2,)! The same value is being broadcast to both dimensions.
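The mechanism is easy to reproduce outside diffrax: broadcasting a scalar noise increment across both dimensions means the two coordinates receive identical noise at every step, so samples collapse onto the x = y line. A plain NumPy sketch of the effect (not diffrax itself):

```python
import numpy as np

rng = np.random.default_rng(0)
n_steps, dt = 1000, 1e-3

# Scalar noise broadcast to both dimensions: increments are identical.
x_broadcast = np.zeros(2)
for _ in range(n_steps):
    dw = rng.normal() * np.sqrt(dt)   # shape () -- one value for both dims
    x_broadcast = x_broadcast + dw    # broadcasting copies dw to each coord

# Independent noise per dimension: increments are uncorrelated.
x_independent = np.zeros(2)
for _ in range(n_steps):
    dw = rng.normal(size=(2,)) * np.sqrt(dt)
    x_independent = x_independent + dw

print(x_broadcast)    # the two coordinates are exactly equal
print(x_independent)  # the two coordinates differ
```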

The fix is to adjust the Brownian motion:

- brownian_motion = VirtualBrownianTree(t0, t1, tol=1e-5, shape=(), key=keys[1])
+ brownian_motion = VirtualBrownianTree(t0, t1, tol=1e-5, shape=(2,), key=keys[1])

and correspondingly also adjust the diffusion to be diagonal with constant value:

+ import lineax as lx
+
  def disp(t,y,args):
-     return dispersion(1-t)
+     struct = jax.eval_shape(lambda: y)
+     return lx.IdentityLinearOperator(struct, struct) * dispersion(1-t)

I appreciate the use of Lineax there probably feels like it's coming out of left field! I'm giving you the most efficient way to represent a vector field of this type. You could also write jnp.diag(jnp.broadcast_to(dispersion(1 - t), (2,))), it just wouldn't be as computationally efficient.
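For concreteness, the less efficient jnp.diag alternative builds a dense diagonal matrix whose diagonal is the (scalar) dispersion, matching a Brownian motion of shape (2,). The same shape logic in NumPy, with a hypothetical stand-in for the notebook's dispersion schedule:

```python
import numpy as np

def dispersion(t):
    # Hypothetical stand-in for the notebook's dispersion schedule.
    return 1.0 + t

def disp(t, y, args):
    # Dense-diagonal version of the diffusion: a (2, 2) matrix whose
    # diagonal is the scalar dispersion, broadcast to y's shape.
    return np.diag(np.broadcast_to(dispersion(1 - t), y.shape))

d = disp(0.5, np.zeros(2), None)
print(d)
# [[1.5 0. ]
#  [0.  1.5]]
```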

FWIW you're not the first to run into this footgun. For this reason I'm planning on making the ControlTerm API much stricter, so that broadcasting between the diffusion and the Brownian motion isn't allowed!

@patrick-kidger patrick-kidger added the question User queries label Jan 5, 2025
@aurelio-amerio
Author

Thank you so much for the help! I imagined the shape would have to be 2, but I couldn't figure out why it was working even without a shape, or what the appropriate shape for the dispersion term was (in hindsight, it makes sense now).

Thank you again, and have a great day!
