Would it be possible to use a pretrained model (such as Topview Mouse from DLC animal zoo)? #158
@hummuscience thanks for the kind words! We currently do not offer any pretrained models, though we hope to start providing some within the year. Once DLC releases the dataset used to train their TopViewMouse network, we can use that to train an LP version. I'll leave this issue open for now and update it once we start working on this. We are also considering providing a pretrained model from the Facemap dataset (mouse face from different angles) and the CRIM13 dataset (top-down view of two mice, one black and one white). If anybody has additional pretrained models they would like to see (and, importantly, a pointer to a labeled dataset), please let us know!
I just checked the AnimalZoo preprint and the references it cites. It doesn't seem that any of the actual labeled datasets are available, but in some cases the videos are (see here for example: https://zenodo.org/records/3608658). I wonder if it would make sense to run the TopViewMouse model on these videos, extract some output frames (refining them in case of errors), and then convert the project to a DLC project. It might take some time, but could be worth the effort (for me, at least). How many frames would you aim for in this case? The Facemap model would be quite useful, since many people working with head-fixed mice could benefit from it. The CRIM13 model is also interesting, but wouldn't have that many applications since it's specific to assays with two animals (one white, the other black), unless this is a more common assay that I am not aware of. On the subject of multi-animal tracking: are there currently any options with LP?
We've been in touch with the DLC folks and they plan to release the labeled datasets once the paper is out. However, it is not clear how long that will take so your suggestion is probably the quickest way to getting something workable. This isn't something we have the bandwidth to work on right now, but if you're up for trying it that would be great! I'd suggest labeling all the video frames with the TopViewMouse model, then doing the following:
I would also note that we find, even with the LP bells and whistles, that more labeled frames are always better. So if you're not too daunted by the refinement step, selecting 2k or 3k frames will almost certainly result in a more robust model than 1k frames, though 1k will certainly do a good job. If you go down this route please let me know! Happy to keep discussing it with you. Re: multi-animal tracking - this is something that we are working towards, but it will be a while before we have these features built into LP (unless the animals are visibly distinct, like in the CRIM13 dataset).
So, I started to implement this for some of my own videos and am currently refining the predictions on 300 images to test if it improves things. Meanwhile, it seems like the datasets from the SuperAnimal paper are now public (or at least, I found them on Zenodo: MausHaus and the whole TopViewMouse dataset). I tried the pretrained model on some images from the MausHaus and the BlackMice datasets, but the results were actually not as good as I expected. I would expect the pretrained model to reproduce the labels from its own training set, but it didn't (or rather, it performed poorly). Now there is also the entire TopViewMouse dataset. There, not all keypoints are labelled in all images (since they come from different datasets). I wondered if I can just go ahead and train LP with it like that. The other issue is that the annotations are in a JSON file, though I think I could manage to convert them to LP format.
I looked at the TopViewMouse demo a year or so back and also found that it did not perform as well as expected on a top-view mouse video that looked very similar to one in the training dataset. Good to know that you could replicate this finding. I still haven't had a chance to look into the TopViewMouse dataset - I do remember that not all keypoints are labeled in all images, but are the keypoints at least named the same when they are in the same location across datasets? If so, then you can definitely train an LP model on these; you would just leave the ground truth label empty where it doesn't exist, and LP then ignores this keypoint during training. It shouldn't be too hard to convert the annotations from JSON to LP format. If you end up doing this please let me know and we can discuss the best ways to train the model!
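As a rough illustration of the JSON-to-LP conversion, here is a minimal sketch. It assumes a COCO-style keypoint file (top-level `images`, `annotations`, and `categories` with a `keypoints` name list; keypoints stored as flat `[x, y, v, ...]` triplets where `v == 0` means not labeled), one animal per image, and the DLC-style three-row CSV header (`scorer` / `bodyparts` / `coords`) for the labeled-data file. The real TopViewMouse JSON layout may differ. Unlabeled keypoints are written as NaN, i.e. the empty ground truth described above.

```python
import csv
import io
import json
import math


def coco_to_lp_rows(coco, scorer="converted"):
    """Convert a COCO-style keypoint dict into rows of a DLC/LP-style CSV.

    Assumes one annotation per image (single animal) and that a visibility
    flag of 0 means "not labeled"; such keypoints become NaN (empty label).
    """
    names = coco["categories"][0]["keypoints"]
    files = {img["id"]: img["file_name"] for img in coco["images"]}
    rows = [
        ["scorer"] + [scorer] * (2 * len(names)),
        ["bodyparts"] + [name for name in names for _ in range(2)],
        ["coords"] + ["x", "y"] * len(names),
    ]
    for ann in coco["annotations"]:
        kps = ann["keypoints"]
        row = [files[ann["image_id"]]]
        for i in range(len(names)):
            x, y, v = kps[3 * i : 3 * i + 3]
            row += [math.nan, math.nan] if v == 0 else [x, y]
        rows.append(row)
    return rows


def rows_to_csv_text(rows):
    """Serialize rows with the csv module (NaN is written as 'nan')."""
    buf = io.StringIO()
    csv.writer(buf).writerows(rows)
    return buf.getvalue()


# Tiny hypothetical example: one image, two keypoints, the second unlabeled.
example = json.loads("""{
  "images": [{"id": 1, "file_name": "labeled-data/sess1/img001.png"}],
  "categories": [{"keypoints": ["nose", "tail_base"]}],
  "annotations": [{"image_id": 1, "keypoints": [10, 20, 2, 0, 0, 0]}]
}""")
rows = coco_to_lp_rows(example)
csv_text = rows_to_csv_text(rows)
```

For the real file you would load it with `json.load` and write `csv_text` (or the rows directly via `csv.writer` on an open file); the `nan` entries are what leave those keypoints out of training.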
The results are a bit better when one uses spatio-temporal adaptation, but not as good as I would expect. Yes, the positions are the same. Then I will go ahead and give it a try. I will report back :)
Awesome, excited to see the results! I take it the labeled dataset doesn't come with the associated videos? Maybe I can ask the Mathises for that data; then we could extract the context frames and test out a context model as well.
Training is running 👍 will post once it's done. Yeah, the dataset doesn't contain any videos. I think some of the source datasets could have videos (Pranav 2018 maybe?). But yeah, it could be easier to ask the Mathises for the videos, even though it is possible that they won't have them... Would it be possible to inform the context model with a non-context model somehow? Maybe one could use the PCA? Btw, I wrote a script that automatically extracts the context frames for each image. Could that be useful to add as a utility in the scripts folder?
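The context-frame script itself isn't shown in the thread. As a hypothetical sketch of its naming step, the helper below assumes frames are named with a zero-padded index (e.g. `img00123.png`) and that the context model wants the two frames before and after each labeled frame; both are assumptions here. Extracting the frames from the video (e.g. with OpenCV) is left out.

```python
import re
from pathlib import PurePosixPath


def context_frame_names(frame_path, offsets=(-2, -1, 0, 1, 2)):
    """Return paths of the neighboring frames around a labeled frame.

    Assumes the file stem ends in a zero-padded frame index, e.g. 'img00123'.
    Indices below 0 are clamped to 0, so the first frame is repeated at the
    start of a video; the end of the video is not handled here because the
    number of frames is unknown to this helper.
    """
    path = PurePosixPath(frame_path)
    match = re.search(r"(\d+)$", path.stem)
    if match is None:
        raise ValueError(f"no frame index found in {frame_path!r}")
    digits = match.group(1)
    prefix = path.stem[: match.start()]
    index = int(digits)
    names = []
    for offset in offsets:
        j = max(index + offset, 0)
        names.append(str(path.with_name(f"{prefix}{j:0{len(digits)}d}{path.suffix}")))
    return names
```

Keeping the original zero-padding width matters, since the neighbor names must match the files written by whatever tool extracted the frames.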
First look at the training: it seems like I should be stopping the training earlier (150k?) or at least saving more checkpoints. Test videos look very good :) I am thinking of training a DLC model with the same dataset (maybe some shuffles?) to compare outputs. There are 300 additional frames not contained in the original TopViewMouse5k that come from my own datasets. I will check the evaluation. I am quite new to LP, so I am not so sure about the choices in the config.yaml file. I added it below.
I tried to run a semi-supervised model (PCA or temporal), but it fails because the input images are of different sizes. I could rescale the images and the keypoints; do you think that would make sense?
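If the images are resized to a common shape, the keypoints have to be scaled by the same per-axis factors. A minimal sketch, assuming a plain resize (no cropping or padding) and keypoints in pixel coordinates with NaN for unlabeled points:

```python
import math


def rescale_keypoints(keypoints, src_size, dst_size):
    """Scale (x, y) keypoints from an image of size src_size to dst_size.

    Sizes are (width, height). NaN coordinates stay NaN, so unlabeled
    keypoints remain unlabeled after rescaling.
    """
    sx = dst_size[0] / src_size[0]
    sy = dst_size[1] / src_size[1]
    return [(x * sx, y * sy) for x, y in keypoints]


scaled = rescale_keypoints([(100.0, 50.0), (math.nan, math.nan)], (640, 480), (320, 240))
```

Unequal factors (different `sx` and `sy`) stretch the animal, and resizing does not equalize the animal's apparent size across source datasets.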
ah very cool, glad you were able to get this working!
yes, PCA will require the frames to be the same size; but actually this is an interesting use case, because even if you resized the frames to be the same size, the size of the animal would vary a lot from dataset to dataset, so maybe PCA wouldn't work so well anyway. I'll have to give this some more thought.
everything else looks good! I see you set imgaug to dlc-top-down already, which is great 👌
the lack of context frames and the frames being different sizes means it is difficult to play with the context/semi-supervised features of the model. there's not a real easy workaround to either of these with this dataset that I can think of off the top of my head. So I guess I'd say use the fully supervised model first and see how that works for you?
I'm kinda surprised the validation loss starts going up so early - but my intuition is that this is related to the big resize dim (512x512); I'm curious if this goes away with smaller resize dims.
Re-training now with your suggestions. I have an RTX A6000, so I increased the batch size to 32 without any memory error (after 40 epochs). Interestingly, the model is not much faster: an epoch takes 1 minute. I was checking the loss development with TensorBoard, and it seems like the loss is decreasing faster than in the first model I trained. It could be that the larger training set and the larger batch size make the loss decrease faster and the model generalize better.
Going through the evaluation results of the first model I trained, it seems like the model doesn't generalize across datasets where fewer keypoints are labelled. For example, in this dataset, only 4 keypoints were labelled. When we predict all 26 keypoints, the model does quite a bad job at it, and with high confidence. Am I doing something wrong? In the SuperAnimal paper, they mention using gradient masking of the heatmaps to deal with this issue. I have no clear idea what that means, though.
I guess this makes sense: you have fewer batches but each batch takes longer to process. Is your GPU utilization at or near 100% while training? If not you could also increase
You are correct that this is due to the increased batch size - you can see in the first model that there is a big dip around 10k iterations. This is actually due to the unfreezing of the backbone weights. When you increase the batch size you take fewer steps per epoch, and so the backbone is unfrozen earlier (in terms of number of gradient steps). Actually in your case because there are so many frames you could probably reduce the parameter
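The batch-size effect described above comes down to simple arithmetic. A sketch, assuming the backbone is unfrozen at the start of a fixed epoch and that the last partial batch of each epoch is kept; the frame count and epoch below are hypothetical numbers for illustration:

```python
import math


def unfreeze_step(n_train_frames, batch_size, unfreeze_epoch):
    """Gradient step at which the backbone is unfrozen.

    Assumes unfreezing happens at the start of `unfreeze_epoch` and that
    each epoch ends with a (possibly partial) final batch.
    """
    steps_per_epoch = math.ceil(n_train_frames / batch_size)
    return unfreeze_epoch * steps_per_epoch


# Hypothetical: 4000 training frames, backbone unfrozen at epoch 20.
print(unfreeze_step(4000, 8, 20))   # 10000 steps
print(unfreeze_step(4000, 32, 20))  # 2500 steps
```

Quadrupling the batch size cuts the number of gradient steps before unfreezing by the same factor; with a large dataset, each epoch already contains many steps, which is why an earlier unfreeze epoch can be reasonable.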
Huh, this is a bit unexpected; I would have guessed the model could generalize better than this. I'm curious if the 256x256 model generalizes better. This is a great test, though! Do you see this kind of issue with other datasets that have few labeled keypoints? I forget how exactly the SuperAnimal paper does the gradient masking, but I'll look into that and get back to you.
I found the code that DLC uses for gradient masking in their dlc_pytorch branch
If I understand this correctly, the gradient masking is applied to the weights before the backward pass? But I am not sure, since in the SuperAnimal paper they mention that it is applied before the loss calculation. In LP the masking is done during [the loss calculation](https://github.com/danbider/lightning-pose/blob/4967266feb59a8c08ff3c31d08f520d480cd10d1/lightning_pose/losses/losses.py#L149), as in, all NaNs are removed before the loss is calculated. Is that the same approach, though? The SuperAnimal paper shows some example images with vs. without masking. The methods section mentions the following: Training naively on these projected annotations would harm the training stability, as the loss function penalizes undefined keypoints, as if they were not visible (i.e., occluded). For stable training of our panoptic pose estimation model, we mask components of the loss function across keypoints. The keypoint mask with Note that we make distinct the difference between not annotated and not defined in the original dataset and we only mask undefined keypoints. This is important as, in the case of sideview animals, "not annotated" could also mean occluded/invisible. Adding masking to not annotated keypoints will encourage the model to assign high likelihood to occluded keypoints.
Does LP distinguish between "unlabelled" keypoints (because of occlusions) and keypoints that were not labelled at all?
No, currently LP does not distinguish between the two - if a ground truth label is missing it is dropped from the loss function with the So yes, on a first pass it appears that LP by default does the same gradient masking that the SuperAnimal paper implements. Looking at the DLC function you linked,
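To make the comparison concrete, here is a minimal NumPy sketch of NaN-based masking applied inside the loss (the behavior described above, not LP's actual code). The optional `defined` mask is a hypothetical extension in the spirit of the SuperAnimal description: it keeps keypoints that exist in the source dataset but are unannotated in a given frame, which then need an all-zero target heatmap.

```python
import numpy as np


def masked_heatmap_mse(pred, target, keypoints_gt, defined=None):
    """Heatmap MSE that ignores keypoints excluded by a mask.

    pred, target: arrays of shape (batch, n_keypoints, H, W).
    keypoints_gt: array of shape (batch, n_keypoints, 2), NaN where unlabeled.
    defined: optional boolean array (batch, n_keypoints). If given, it replaces
        the label-based mask, so defined-but-unlabeled keypoints stay in the
        loss; their target heatmaps should then be all zeros (not NaN).
    """
    labeled = ~np.isnan(keypoints_gt).any(axis=-1)
    mask = labeled if defined is None else np.asarray(defined, dtype=bool)
    if not mask.any():
        return 0.0
    per_keypoint = ((pred - target) ** 2).mean(axis=(-2, -1))
    return float(per_keypoint[mask].mean())
```

With the default mask, every missing label is ignored, which matches the description above: occluded and undefined keypoints are treated alike. Passing `defined` is one way the two cases could be separated.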
Oh something else I just realized: the TopViewMouse dataset contains some datasets with multiple animals - how do you deal with this right now? LP cannot currently handle multi-animal pose estimation.
I removed the two datasets that have multiple animals (TriMouse and Golden Lab). I also just realized that one of the datasets (Kiehn Lab Openfield) actually doesn't have labels, so I will rerun the training without it.
Great. Let me know how the generalization looks with the 256x256 model when done, I'm still scratching my head a bit about the bad performance in the frames you showed above.
This is the current state of the training (it still includes the Kiehn Lab Openfield data, though). The train_heatmap_mse_loss is plateauing, as is the RMSE loss. Not sure what to think about the val loss. Should I stop the training in this case? If I stop the training, will it save the checkpoint? (I set the config to save multiple checkpoints but realized that this is not implemented on the dynamic_crop branch.)
the noisiness in the validation plot is weird, especially compared to the black line. is black 512x512 and magenta 256x256? one thing you can do is hit the three dots in the upper right-hand corner of these plots and change the y-axis to a log scale; that's typically more helpful the further you get in training - my guess is that the model should be saving out weights along the way, and you'll find them in the tb_logs directory (you'll have to go down a couple more subdirectories).
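To locate the saved weights without clicking through the nested subdirectories, a small helper can search recursively. This assumes the checkpoints use the `.ckpt` extension (as mentioned elsewhere in this thread) and that `tb_logs` is relative to the current working directory:

```python
from pathlib import Path


def find_checkpoints(root="tb_logs"):
    """Return all .ckpt files below `root`, most recently modified first.

    Returns an empty list if `root` does not exist.
    """
    return sorted(
        Path(root).rglob("*.ckpt"),
        key=lambda p: p.stat().st_mtime,
        reverse=True,
    )


for ckpt in find_checkpoints():
    print(ckpt)
```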
btw if you're training on the
Alright, re-training now with the on the unsupervised multiview branch with 8 workers, and unfreeze epoch set to 5. I will also have a look at how well the magenta model performed. |
awesome, thanks so much for looking into this. yeah maybe we can switch over to the discord? https://discord.gg/tDUPdRj4BM to answer your most recent questions here though:
I am currently thinking that the bad predictions for the labels on the sides of the animal have to do with the relative proportion of these labels vs. others (such as ears or tail base) in the whole dataset. I will try some stratification of the train/test data or some oversampling. I haven't found any sign of this being implemented in LP so far. Do you think it might make sense to add?
That's a good observation, it is indeed possible that oversampling would force the model to focus more on these less-frequent keypoints. One way to do this would be to implement a custom sampler for the pytorch data loader, which would allow upweighting of the labeled examples with more keypoints. I don't have the bandwidth to work on this for the time being, but happy to discuss the details more if you're interested in trying it out.
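One way to build per-frame weights for such a sampler is inverse keypoint frequency. A sketch under that assumption (the weighting scheme is one choice among many, and the keypoint names are hypothetical):

```python
from collections import Counter


def keypoint_balanced_weights(labeled_keypoints):
    """Per-frame sampling weights that favor frames with rarely labeled keypoints.

    labeled_keypoints: list where entry i is the set of keypoint names that are
    labeled in frame i. Each keypoint contributes 1 / (number of frames in which
    it is labeled); a frame's weight is the sum over its labeled keypoints.
    Frames with no labels get weight 0 and are never drawn.
    """
    counts = Counter(kp for kps in labeled_keypoints for kp in kps)
    return [sum(1.0 / counts[kp] for kp in kps) for kps in labeled_keypoints]


weights = keypoint_balanced_weights([
    {"nose", "tail_base"},
    {"nose"},
    {"nose", "left_side", "right_side"},
])
```

In PyTorch, these weights could be passed to `torch.utils.data.WeightedRandomSampler(weights, num_samples=len(weights), replacement=True)` and the sampler given to the `DataLoader` via `sampler=` (which cannot be combined with `shuffle=True`). Where that hook would go in LP's data module is not shown here.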
The more interesting question (at least for me) is what happens with this "collapsing" of the left_ and right_ keypoints to the center. Here is an example of the same frame from 4 different models. The "no_topview" model is one that was trained only on in-house data (without the TopViewMouse dataset). You can see that in the other models the lateral keypoints are "collapsing" onto the spine of the animal. If I plot the distances between the left and right keypoints, the moments where this "collapse" happens are those where the distance is close to 0. Compared to the "no_topview" model (in which this collapsing doesn't happen), all the other trained models show this. And they are not even consistent as to when it happens (even though it seems to be frames where the animal stands up, narrowing its body). Here are the distributions of the distance for all trained models: In this case, it seems like inflation of the training dataset leads to a worse result. This is, however, not the case for all datasets: in the OFT dataset, which benefitted from the oversampling/inflation above, this problem is solved :) So I might be getting closer to the solution. I think it has to do with how I am oversampling and weighting the data :)
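The left/right distance used for these plots can be computed directly from the predictions, even on unlabeled video. A minimal NumPy sketch, assuming predictions as `(n_frames, 2)` arrays per keypoint; the collapse threshold is a hypothetical choice that would need tuning to the animal's size in pixels:

```python
import numpy as np


def side_distance(left_xy, right_xy):
    """Per-frame Euclidean distance between a left and a right keypoint.

    left_xy, right_xy: arrays of shape (n_frames, 2) holding (x, y) predictions.
    """
    diff = np.asarray(left_xy, dtype=float) - np.asarray(right_xy, dtype=float)
    return np.linalg.norm(diff, axis=1)


def collapsed_frames(distances, threshold_px):
    """Indices of frames where the two sides are closer than threshold_px."""
    return np.flatnonzero(np.asarray(distances) < threshold_px)


d = side_distance([[0, 0], [0, 0], [10, 10]], [[3, 4], [0, 1], [10, 22]])
collapsed = collapsed_frames(d, threshold_px=2)
```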
@hummuscience wow this is awesome, nice work! some misc comments/questions:
How do I actually do this? "All you need to do is train on the TopViewMouse and then point to those weights in the config file. The context model weights will be randomly initialized, but the ResNet-50 weights from TopViewMouse will be loaded in." I tried to pass the path to the ckpt file as the backbone in the config, but that didn't work. So I went with resnet50_animal_ap10k as the backbone and passed the checkpoint to the "checkpoint" argument. Is that the correct way?
Yes your second approach is the way to go! A little bit inefficient because it builds the backbone, then loads the ap10k weights, and then loads the weights from the checkpoint, but it's the most straightforward way to lay it out in the config file.
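The loading order described above can be illustrated with plain dictionaries standing in for PyTorch state dicts. This is a toy sketch of `strict=False`-style loading semantics (the parameter names are hypothetical, and it is not LP's loading code): each later load overwrites matching keys, and keys absent from the checkpoint, such as a new context head, keep their fresh initialization.

```python
def load_non_strict(model_state, ckpt_state):
    """Overwrite entries of model_state with matching entries of ckpt_state.

    Returns the merged dict plus the keys missing from the checkpoint and the
    checkpoint keys the model does not have, loosely mirroring the
    missing_keys / unexpected_keys reported by PyTorch's
    Module.load_state_dict(..., strict=False).
    """
    merged = dict(model_state)
    for name, value in ckpt_state.items():
        if name in merged:
            merged[name] = value
    missing = [name for name in model_state if name not in ckpt_state]
    unexpected = [name for name in ckpt_state if name not in model_state]
    return merged, missing, unexpected


# Hypothetical parameter names; values just record where the weights came from.
model = {"backbone.conv1.weight": "random", "head.context.weight": "random"}
ap10k = {"backbone.conv1.weight": "ap10k"}
topview = {"backbone.conv1.weight": "topview", "head.single.weight": "topview"}

model, _, _ = load_non_strict(model, ap10k)                    # 1) AP-10K backbone weights
model, missing, unexpected = load_non_strict(model, topview)  # 2) TopViewMouse checkpoint
```

One difference from real PyTorch: `load_state_dict` raises an error on tensor shape mismatches even with `strict=False`, which this dictionary version does not model.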
Here is a small update. I am not completely done with the tests, but it seems like there is a solution to the "collapsing" issue. The upper part is the pixel error on test images using different data as input (combined is GK + ZN, two different experiments). The middle part shows violin plots of the distance between left and right bodyparts in the same test video; the closer to 0, the worse the model. I sorted them by average pixel error to make it easier to see patterns in the combinations. I had a suspicion that something about the size of the input images determined whether I see "collapsing" or not, so I ran some models with 512x512 instead of 256x256 as image_size, and this does the trick. I am not well enough versed in these models, but could it be that reducing the size to 256x256 gets the keypoints "too close" to each other and the model has trouble learning to differentiate them? All these models were run after pre-training on the TopViewMouse5k data with inflated+upsampled frames and 256x256 as image size. After seeing these results I also ran tests on the pre-training step, so I will post an update another time. But 512x512 seems to be a better choice for this data, as it improves the results on datasets where the mice are small because the images are zoomed out.
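The "too close" hypothesis can be checked with quick arithmetic. A sketch with hypothetical numbers (a 24 px left-right distance in a 1280 px wide frame). The optional `heatmap_stride` assumes heatmaps are predicted at a lower resolution than the network input; the factor of 4 used below is an assumption, not LP's configured value.

```python
def resized_distance(distance_px, frame_width, resize_width, heatmap_stride=1):
    """Distance between two keypoints after resizing the frame width.

    Dividing by heatmap_stride expresses the distance in heatmap pixels,
    assuming heatmaps are downsampled by that factor relative to the input.
    """
    return distance_px * resize_width / (frame_width * heatmap_stride)


for size in (256, 512):
    print(size, resized_distance(24, 1280, size), resized_distance(24, 1280, size, heatmap_stride=4))
```

Under these assumptions the sides are 4.8 px apart at 256x256 and 9.6 px at 512x512, or 1.2 vs 2.4 heatmap pixels, so at the smaller size two neighboring peaks could be hard to keep apart.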
That's interesting; we've fit Lightning Pose on some fairly large images with freely moving mice and 256x256 has always worked well for us. But that has always been in the single dataset/single set of keypoints setting, so maybe something about this heterogeneous dataset is different. Have you tried doing some of the comparisons you did previously on the TopViewMouse5k data (with inflated+upsampled frames), looking at 256x256 versus 512x512? It might also be worth checking out 384x384 as an intermediate solution.
Just wanted to give a quick update on the current state of this. I ended up training quite a few models to find out what works best for my data. Here is a summary using 3 different backbones trained on the DLC TopViewMouse dataset and different combinations of losses. I am only using the OOD (i.e., test) dataset for these plots. "Side difference" measures how the ground-truth and predicted distances between the two sides compare, so lower values are better. The models are ordered by how well they perform on a combination of pixel error, side difference, percent correct keypoints (PCK), and calibration error (to account for the overconfidence). I did not alter the parameters of the different losses. One thing that is clear is that training with a larger image size generally performs better than with smaller images. The same goes for context models vs. normal models. Surprisingly, combining both unsupervised losses is not necessarily better than using a single loss. What is still problematic is this "collapsing" or side switching. When checking predictions more closely, it seems not to happen in single frames only, but rather "gradually" over multiple frames. I will change the colormap to make it more obvious and post an example of it. As to the question of how the image size in the backbone and the image size in the "refinement" step affect the "collapsing": it seems like both have an effect, with the size in the refinement step having the larger effect (here it's the raw distance between left and right in the predicted videos):
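For reference, here is one way side difference and PCK could be computed. The exact definitions behind the plots aren't given in the thread, so this sketch assumes: side difference as the absolute difference between ground-truth and predicted left-right distances, PCK with a fixed pixel threshold, and NaN marking unlabeled ground truth.

```python
import numpy as np


def side_difference(gt_left, gt_right, pred_left, pred_right):
    """|ground-truth side distance - predicted side distance| per frame.

    Each argument is an array of shape (n_frames, 2) with (x, y) coordinates.
    """
    gt = np.linalg.norm(np.asarray(gt_left, float) - np.asarray(gt_right, float), axis=1)
    pred = np.linalg.norm(np.asarray(pred_left, float) - np.asarray(pred_right, float), axis=1)
    return np.abs(gt - pred)


def pck(pred, gt, threshold_px):
    """Fraction of labeled keypoints predicted within threshold_px of ground truth.

    pred, gt: arrays of shape (n_frames, n_keypoints, 2); NaN in gt marks
    unlabeled keypoints, which are excluded.
    """
    err = np.linalg.norm(np.asarray(pred, float) - np.asarray(gt, float), axis=-1)
    valid = ~np.isnan(err)
    return float((err[valid] < threshold_px).mean())
```

A threshold relative to animal size (e.g. a fraction of body length) would make PCK more comparable across datasets with different zoom levels; that choice is also an assumption.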
@hummuscience thanks for checking in, I continue to be very curious about these results. Some questions for clarification:
And as a more general question, how far away are the models from being usable for answering your scientific question? How many frames have you ended up labeling so far from your own dataset?
So, I first trained 2 different 'backbones'. DLC256 and DLC512 which either resize to 256x256 or 512x512. I then used these backbones to train different models with my data of interest (freely moving mice). The backbones did not see any of my data and only used data from the DLC model.
And to the general question: I am now using the best-performing model for my actual data (but removing the "side" keypoints, as the switching/collapsing is still a problem for downstream use such as Keypoint-MoSeq etc.). In total I have only labelled 250 images. I am thinking of adding more images to try to improve the "switching" performance.
ah I see now! ok very cool. so the general trends are:
I still think it would be worth trying out a 384x384 model just to see how it fits into the mix.
Yeah. I think when it comes to unsupervised losses it's a bit mixed, while for model type I would say context is a bit better. I might be choosing the wrong ways to compare the models, though; please let me know if you have better ideas. Also, the models were all trained with The DLC team updated their Zenodo repo after I found an error in one of the datasets (which had previously forced me to drop 800 frames when training the backbone). https://zenodo.org/records/10618947 So I will retrain the backbone with the larger dataset and the different sampling options to correct for the imbalance in labels, then run another round and update you :)
Oof, sorry to hear about the issue with the SuperAnimal dataset - yet another good reason to make data publicly available! Regarding Regarding the model evaluation, one recommendation I have would be to look at snippets of held-out video data as well, and not just the labeled data. Since the collapsing sides are still an issue, this is fortunately a metric you can calculate on unlabeled data for all the models - just compute the pixel distance between two selected side keypoints for each frame. Then you could take 5-10 short clips of held-out mice, run inference with some of the models, and compare the side distances. I would also recommend checking out our EKS post-processor - it requires training and inference with multiple models, but we've really found it to be helpful in cleaning up noise in the predictions (as long as the different models, trained on different subsets of the data, aren't all making the same mistake).
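A cheap check to run before the full EKS pipeline is whether the ensemble members disagree at all. The sketch below is not EKS (no Kalman smoothing is involved); it only computes a per-keypoint median across models and the models' average distance from it, assuming predictions stacked as `(n_models, n_frames, n_keypoints, 2)`. Low spread on frames with a visible error would indicate the shared-mistake case mentioned above.

```python
import numpy as np


def ensemble_summary(preds):
    """Median prediction and spread across an ensemble of pose models.

    preds: array of shape (n_models, n_frames, n_keypoints, 2).
    Returns (median, spread): median has shape (n_frames, n_keypoints, 2);
    spread is the mean Euclidean distance of the models from the median,
    with shape (n_frames, n_keypoints).
    """
    preds = np.asarray(preds, dtype=float)
    median = np.median(preds, axis=0)
    spread = np.linalg.norm(preds - median, axis=-1).mean(axis=0)
    return median, spread
```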
True. For the backbone, it's bad because of the different labels. I will only activate it for the fine-tuning. Regarding your suggestion, this is exactly what I did in the second plot: the distributions are from a single held-out video (15 minutes). I will add a few more videos to have more variability. I tried the EKS post-processor, but as you said, since many of the models make similar mistakes, the outcomes were not worth the trouble. I will definitely check it out again once things are solved.
Ah I see that makes a lot of sense - I assumed those were distances computed on labeled frames. I'm surprised that the pca loss actually has a shorter distance. Is one of these distance distributions more "correct" than the others? Also, do you have a sense if the pca loss is good at signaling the side collapse issue? It's interesting that multiple models would have the side collapse issue on the same frames. I guess that means there's something specific about the frames in which the collapse occurs - do you have any sense of this? (sorry if you mentioned this previously) If so, it would be helpful to identify such examples across different animals and add those frames to your labeled data.
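One way to test whether a PCA-based signal flags the collapse frames is to fit PCA on labeled poses and inspect each predicted frame's reconstruction error. A minimal NumPy sketch, assuming fully labeled poses flattened to `(n_frames, 2 * n_keypoints)` and a hypothetical number of components; it mirrors the idea behind a pose PCA penalty, not LP's implementation:

```python
import numpy as np


def fit_pose_pca(poses, n_components):
    """Fit PCA to flattened poses of shape (n_frames, 2 * n_keypoints) without NaNs."""
    poses = np.asarray(poses, dtype=float)
    mean = poses.mean(axis=0)
    _, _, vt = np.linalg.svd(poses - mean, full_matrices=False)
    return mean, vt[:n_components]


def reconstruction_error(poses, mean, components):
    """Per-frame distance between each pose and its projection onto the PCA subspace."""
    centered = np.asarray(poses, dtype=float) - mean
    projected = centered @ components.T @ components
    return np.linalg.norm(centered - projected, axis=1)


# Toy data: 2-D "poses" that lie on a line, so one component explains them.
mean, comps = fit_pose_pca([[0, 0], [1, 1], [2, 2], [3, 3]], n_components=1)
errors = reconstruction_error([[4, 4], [0, 1]], mean, comps)
```

Frames whose error stands out from the rest of a video could be reviewed by eye and, if they show the collapse, added to the labeled set.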
I have been playing around with lightning pose the past few days and quite impressed with the training speed and performance!
Coming from DeepLabCut, I am testing LP on videos of mice captured from the top view. As you probably know, the DLC animal zoo has a pretrained model for this scenario.
Would it be possible to use that as a backbone for LP instead of the typical ResNets?
I am still new to your codebase, so I might not have understood it in depth yet and may be missing something here...