Hello,
I am trying to use diffrax as an SDE solver to reverse a diffusion process in the context of diffusion models for machine learning.
In particular, given the SDE of the forward diffusion process, I want to solve the reverse SDE to be able to sample new data.
In every implementation I've seen it seems that they are reinventing the wheel and reimplementing an SDE solver, which I would like to avoid.
Coming to my problem, when I implement "manually" the Euler algorithm, I obtain reasonable results (I obtain samples close to every point of the training dataset with the same probability). On the other hand, when using diffrax to solve the reverse SDE, not every solution seems equiprobable, which is very odd, and problematic for this application.
I attach a Colab notebook with a minimal example of the problem for a very simple dataset:
https://colab.research.google.com/drive/1V1nU3vn9hkZvWcJnaYYWOwnsOvBB9iSd?usp=sharing
In this case, I don't implement the NN which would approximate the score, and compute the score directly for the dataset, as it is very simple, to test whether the SDE solver behaves as expected.
If you could help me figure out this problem I would really appreciate it. I am not very familiar with either diffusion models or this library, so I don't have a good intuition, but could it be a problem with the RNG somehow?
Thank you very much for your help and time!
So the key thing to spot is that the denser points are aligned with the line x=y. This is suggestive that you are getting correlated noise between your two evolving dimensions. And indeed this:
I appreciate that the use of Lineax there probably feels like it's coming out of left field! I'm giving you the most efficient way to represent a vector field of this type. You could also write jnp.diag(jnp.broadcast_to(dispersion(1 - t), (2,))), it just wouldn't be as computationally efficient.
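To make the correlated-noise effect concrete, here is a minimal sketch in plain NumPy rather than Diffrax, so it only illustrates the broadcasting and not the solver itself. It assumes the cause is a single scalar Brownian increment being broadcast across both state dimensions; the dispersion function, step size and sample count below are hypothetical stand-ins, not values from the notebook.

```python
import numpy as np


def dispersion(t):
    # Hypothetical scalar dispersion g(t); a stand-in for the one in the thread.
    return 1.0 + t


rng = np.random.default_rng(0)
n_samples, dt, t = 10_000, 0.01, 0.5
g = dispersion(1 - t)

# Footgun: one scalar Brownian increment per sample, broadcast over both
# dimensions, so x and y always receive exactly the same kick.
dw_scalar = rng.normal(scale=np.sqrt(dt), size=(n_samples, 1))
noise_shared = g * np.broadcast_to(dw_scalar, (n_samples, 2))

# Fix: a 2-dimensional Brownian increment and a diagonal (2, 2) diffusion
# matrix, so each dimension gets its own independent kick.
dw_vector = rng.normal(scale=np.sqrt(dt), size=(n_samples, 2))
diffusion = np.diag(np.broadcast_to(g, (2,)))
noise_independent = dw_vector @ diffusion.T

# Correlation between the x and y components of the noise in each case.
corr_shared = np.corrcoef(noise_shared.T)[0, 1]
corr_independent = np.corrcoef(noise_independent.T)[0, 1]
print("shared increment:", corr_shared)
print("independent increments:", corr_independent)
```

The shared-increment noise lies entirely on the line x=y, which would match the denser points seen along that line, while the diagonal version leaves the two components essentially uncorrelated.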
FWIW you're not the first to run into this footgun. For this reason I'm planning on making the ControlTerm API much stricter, so that broadcasting between the diffusion and the Brownian motion isn't allowed!
Thank you so much for the help! I imagined the shape would have had to be 2, but I couldn't figure out why it was working even without a shape, and what was the appropriate shape for the dispersion term (in hindsight, now it makes sense).