Questions on how to transfer your concept into other similar projects #13

Ltwicke commented Nov 26, 2024

Hi all,

I came across your paper while searching for methods to combat mode collapse when training MDNs, and I am wondering whether, and to what degree, my case can profit from your suggested training procedure. I know this is usually more of an email topic, but I want to make it public in case other researchers with similar projects have similar questions.

Very short summary of my project:

In the course of my master's thesis in biophysics, I decided to train Gaussian mixture models on 1D distributions of molecular properties that were calculated at a quantum level of theory. Here are the similarities to your work:

| | Your work | My work |
| --- | --- | --- |
| Context-aware input | History of frames, object masks | Diverse molecular input representations |
| Ground-truth distribution | Simulated synthetic dataset with policies | Fitted distribution of the full set of values at the desired chemical accuracy |
| Individual $\hat{y}$s | Individual future ground-truth locations | Individual values in one property distribution |
| Goal | Predict the full future distribution based on context | Predict the full property distribution(s) based on context |
| Architecture | Encoder encodes frames and predicts hypotheses; decoder generates a mixture based on the hypotheses | (Currently) Encoder encodes molecular input; decoder takes the context vector and decodes it into a mixture with LSTMs (rough sketch below the table) |
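For reference, here is a simplified sketch of the kind of mixture head I mean on my side. The layer names and the plain linear layers are illustrative only; my actual decoder is LSTM-based:

```python
import torch
import torch.nn as nn
import torch.nn.functional as F

class MixtureHead(nn.Module):
    """Maps a context vector to K Gaussian mixture triplets (mu_k, sigma_k, pi_k)."""

    def __init__(self, context_dim=1024, n_components=10):
        super().__init__()
        self.n_components = n_components
        # one linear layer per parameter group, purely for illustration
        self.mu_head = nn.Linear(context_dim, n_components)
        self.sigma_head = nn.Linear(context_dim, n_components)
        self.pi_head = nn.Linear(context_dim, n_components)

    def forward(self, context):
        mu = self.mu_head(context)                              # (B, K) component means
        sigma = F.softplus(self.sigma_head(context)) + 1e-6     # (B, K) positive scales
        pi = F.softmax(self.pi_head(context), dim=-1)           # (B, K) mixture weights
        return mu, sigma, pi
```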

In my work, however, the distributions are 1D, and they can be described with a mixture of 10 Gaussians at the desired resolution. Here are some visual examples:

[image: example property distributions]
(Left: blue: ground-truth distributions, gray histograms: individual $y$ values, ignore the stars.
Right: red: predicted distributions of the trained network, green: ground-truth distribution. The distributions are already normalized for training; the support is just manually extended to encourage desirable gradient behaviour.)
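For completeness: the ground-truth curves above come from fitting a mixture to the full set of calculated values per molecule. A minimal sketch of that step, assuming a plain sklearn fit (my actual pipeline differs in the details):

```python
import numpy as np
from sklearn.mixture import GaussianMixture

# values: 1D array of individual property values for one molecule (illustrative data)
values = np.random.randn(5000)

gmm = GaussianMixture(n_components=10, random_state=0).fit(values.reshape(-1, 1))

# evaluate the fitted density on a grid to obtain the ground-truth distribution curve
grid = np.linspace(values.min(), values.max(), 500).reshape(-1, 1)
density = np.exp(gmm.score_samples(grid))
```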

The problem that got me here:

I use a classic encoder-decoder architecture with a bottleneck latent representation of size 1024. During training, I notice that some properties generally suffer from mode collapse, while others suffer only slightly (see the picture on the right). It is important to note that this happens on the training data (and therefore also on validation), so the problem lies deeper than insufficient generalization.

I already came to the conclusion that I should train the encoder and decoder individually in an alternating fashion, but this only slightly increases performance. I suspect that the current encoder is not really capable of effectively encoding the full location and scale of the ground-truth distribution. The decoder collapses into predicting a sequence of triplets ($\mu_k, \sigma_k, \pi_k$) that together reconstruct a unimodal distribution spanning the entire multimodal ground-truth distribution (as shown in the picture).
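For clarity, the loss I am minimizing is the standard mixture negative log-likelihood over the individual values, roughly like this minimal sketch (not my exact code); this is where the collapse into one broad component shows up:

```python
import torch

def mdn_nll(mu, sigma, pi, y):
    """Negative log-likelihood of scalar targets y under a 1D Gaussian mixture.

    mu, sigma, pi: (B, K) mixture parameters; y: (B,) individual property values.
    """
    y = y.unsqueeze(-1)                                 # (B, 1), broadcast against K components
    log_prob_k = torch.distributions.Normal(mu, sigma).log_prob(y)          # (B, K)
    log_mix = torch.logsumexp(torch.log(pi + 1e-12) + log_prob_k, dim=-1)   # (B,)
    return -log_mix.mean()
```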

How you can help me:

If I understand correctly, the encoder does most of the heavy work in your architecture. Visually, I'd expect the hypotheses to move to the individual modes for multimodal distributions in 1D in my case, while for unimodal ground-truth distributions they would all end up at the single mode. What do you think:

  1. Given the hope that the encoder can create meaningful hypotheses, would they spread out to the individual modes? How should I tune the hyperparameter K, given that I have at most 10 peaks?

  2. If the encoder were unable to map context to meaningful hypotheses, how would this most likely manifest in the prediction performance? Would it still just predict hypotheses at a single mode, or rather try to spread out the hypotheses to catch at least something on the real axis?

  3. Where exactly can I find the latent representation of the input (in my case, the molecule) in your architecture? Is it one layer before the hypotheses are predicted? If I were to do a dimensionality-reduction analysis to compare where molecules are mapped to with regard to the complexity of their distributions, which intermediate output would I use?

  4. As is common in deep learning with molecules, the dataset split is done by separating molecules with similar motifs, making the train-val-test split particularly challenging. I'm expecting that overfitting will become a problem very quickly. How and where can I regularize the architecture to slow down the overfitting process?

  5. What would be the best way to integrate your code into another project? Should I use EWTA_MDF() as a wrapper and build everything into this? I will need to rewrite this for PyTorch anyway; are there any TensorFlow mechanics involved that cannot be directly transcribed into PyTorch? (A rough sketch of how I currently picture the WTA step in PyTorch follows below this list.)
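To make questions 1 and 5 more concrete, here is how I currently picture a (relaxed) winner-takes-all step on scalar hypotheses in PyTorch. This is only my reading of the EWTA idea, not your implementation; `top_k` and `eps` are placeholders for whatever schedule and relaxation you actually use:

```python
import torch

def wta_loss(hyps, y, top_k=1, eps=0.05):
    """(Relaxed) winner-takes-all loss for 1D point hypotheses.

    hyps: (B, K) scalar hypotheses; y: (B,) ground-truth values.
    The top_k closest hypotheses get full weight, the rest a small eps weight;
    annealing top_k from K down to 1 would give an 'evolving' WTA schedule.
    """
    dist = (hyps - y.unsqueeze(-1)) ** 2                           # (B, K) squared error per hypothesis
    winner_idx = dist.topk(top_k, dim=-1, largest=False).indices   # (B, top_k) closest hypotheses
    weights = torch.full_like(dist, eps)                           # small weight for the losers
    weights.scatter_(-1, winner_idx, 1.0)                          # full weight for the winners
    return (weights * dist).sum(dim=-1).mean()
```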

Finally,

I know this is a very big ask, and I don't even know if the issues section is meant for content like this, but I think good general answers on how to use your findings will help not only me but also others who come along in similar situations. I'm very thankful for any kind of feedback to the extent possible given this short introduction to the project, and I will definitely give an update on the final performance using your suggested method.

Thank you and have a wonderful day,
A fellow deep learning engineer
