2d flow #2301
base: feature/diff-code
Conversation
…or 2d and 3d flow experiments
@@ -0,0 +1,43 @@
from ml_collections import config_dict
Can we use a different name than "mod"? Is this the config for the 2D flow?
Sure
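For illustration, a hedged sketch of what a more descriptively named config module might look like with ml_collections; the function name and field values here are assumptions pulled from other hunks in this PR, not the file's actual contents:

```python
from ml_collections import config_dict


def get_config():
    # Hypothetical replacement for "mod"; fields are illustrative guesses.
    config = config_dict.ConfigDict()
    config.dim = 64
    config.sampling_steps = 20
    config.loss = "l2"
    config.data_config = config_dict.ConfigDict()
    config.data_config.img_channel = 3
    return config
```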
@@ -1031,7 +1031,7 @@ def p_losses(self, stack, hres, lres, ures, t, noise = None):
    loss2 = self.loss_fn(x_start, warped, reduction = 'none')
    loss2 = reduce(loss2, 'b ... -> b (...)', 'mean')

-   return loss.mean() + loss1.mean() + loss2.mean()
+   return loss.mean()*1.7 + loss1.mean()*1.0 + loss2.mean()*1.0
Interesting, so now the losses are not quite balanced? How did you determine the 1.7? (nit: both of the 1.0's look superfluous, and at least one of them is: you only need to weight 2 of the 3 loss terms, or in general k-1 of k loss terms, since it's the relative weights that matter.)
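To make the k-1-of-k point concrete, a small self-contained check (with illustrative losses, not the PR's): scaling all three weights by a common constant only rescales the gradient, which the learning rate absorbs, so only the relative weights matter.

```python
import torch

x = torch.randn(4, requires_grad=True)
l0, l1, l2 = (x ** 2).mean(), x.abs().mean(), x.sin().mean()

# Original weights vs. the same weights divided by 1.7.
g_a = torch.autograd.grad(1.7 * l0 + 1.0 * l1 + 1.0 * l2, x, retain_graph=True)[0]
g_b = torch.autograd.grad(1.0 * l0 + (1.0 / 1.7) * l1 + (1.0 / 1.7) * l2, x)[0]

print(torch.allclose(g_a, 1.7 * g_b))  # True: gradients differ only by a scale
```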
Aah, it's a typo: the last one was supposed to be 0.3 so that the weights sum to 3. But yeah, these are manually chosen hyperparameters, which even I'm not much of a fan of. I was thinking of using the SoftAdapt paper to figure out these weights automatically.
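For reference, a minimal sketch of the SoftAdapt idea (a softmax over the recent rates of change of each loss term); this is my reading of the paper, not code from this PR, and beta and the history handling are assumptions:

```python
import numpy as np


def softadapt_weights(loss_history, beta=0.1):
    """loss_history: (T, k) array of the last T values of k loss terms.

    SoftAdapt assigns larger weight to terms whose loss is decreasing the
    slowest (or increasing), via a softmax over the per-term slopes.
    """
    slopes = loss_history[-1] - loss_history[-2]   # rate of change per term
    z = beta * (slopes - slopes.max())             # shift for numerical stability
    w = np.exp(z)
    return w / w.sum()


# Usage: recompute weights every few steps from logged loss values.
history = np.array([[0.9, 0.5, 0.31], [0.7, 0.49, 0.32]])
print(softadapt_weights(history))
```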
flow = Flow(
    dim = config.dim,
    channels = 3 * config.data_config["img_channel"],
    out_dim = 2,
So this is the 2D part?
Flow is used in both 2D and 3D; out_dim is 2 in 2D and 3 in 3D. The warp it drives is scale_space_warp in 3D and flow_warp in 2D.
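For context, a minimal flow_warp sketch in PyTorch, the standard grid_sample-based backward warp; this is the common implementation pattern, not necessarily the exact code in this repo:

```python
import torch
import torch.nn.functional as F


def flow_warp(x, flow):
    """Backward-warp x (B, C, H, W) by a 2D flow field (B, 2, H, W) in pixels."""
    b, _, h, w = x.shape
    # Base sampling grid in pixel coordinates.
    yy, xx = torch.meshgrid(
        torch.arange(h, device=x.device, dtype=x.dtype),
        torch.arange(w, device=x.device, dtype=x.dtype),
        indexing="ij",
    )
    grid_x = xx.unsqueeze(0) + flow[:, 0]   # displaced x coordinates
    grid_y = yy.unsqueeze(0) + flow[:, 1]   # displaced y coordinates
    # Normalize to [-1, 1] as grid_sample expects.
    grid_x = 2.0 * grid_x / max(w - 1, 1) - 1.0
    grid_y = 2.0 * grid_y / max(h - 1, 1) - 1.0
    grid = torch.stack((grid_x, grid_y), dim=-1)   # (B, H, W, 2)
    return F.grid_sample(x, grid, mode="bilinear",
                         padding_mode="border", align_corners=True)
```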
-config.sampling_steps = 6
-config.loss = "l1"
+config.sampling_steps = 20
+config.loss = "l2"
This seems like a major change. It would be great if there was a clearly defined experiment behind this choice.
I ran a couple of experiments with fewer epochs, but they were cluttering the wandb project space, so I deleted them.
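A hedged sketch of how such an ablation could be kept reproducible without cluttering the main project: a short sweep logged to a separate wandb project. The project name, the train_eval helper, and the grid are all assumptions, not code from this repo.

```python
import itertools
import wandb

# Hypothetical helper: trains for a few epochs and returns a validation metric.
from train import train_eval  # assumed to exist in this repo

for steps, loss in itertools.product([6, 20], ["l1", "l2"]):
    run = wandb.init(
        project="2d-flow-ablations",   # separate project keeps the main space clean
        config={"sampling_steps": steps, "loss": loss},
        reinit=True,
    )
    val_metric = train_eval(sampling_steps=steps, loss=loss, epochs=5)
    wandb.log({"val_metric": val_metric})
    run.finish()
```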
…el, minipatch, partial rollout
…sr-focal added additionally
added a version for 2d flow
Only the forward, p_losses, and sample functions of the diffusion model change.
3D flow used Gaussian pyramids and scale_space_warp; 2D flow only uses flow_warp.
Compared to the cold start problem, both new versions use a SwinIR that is instantiated in the diffusion model and used to upsample the low-res input into ures, which provides the context instead of hres.
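A rough sketch of the conditioning described above, with every module and argument name an assumption: SwinIR upsamples the low-res input to ures once inside the diffusion model, and ures replaces hres as the denoiser's context.

```python
import torch.nn as nn


class Diffusion(nn.Module):
    def __init__(self, denoiser, swin_ir):
        super().__init__()
        self.denoiser = denoiser
        self.swin_ir = swin_ir   # instantiated inside the diffusion model

    def forward(self, x_noisy, t, lres):
        # Upsample the low-res input to serve as context instead of hres.
        ures = self.swin_ir(lres)
        return self.denoiser(x_noisy, t, context=ures)
```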