Description
This PR visualizes how the weights are loaded into TransformerLens from HuggingFace for all the different models. The user can access this visualization via two static functions of the class WeightConversionUtils. The function model_info takes the name of the model and prints the visualization, whereas the function model_info_cfg takes a HookedTransformerConfig and prints the same visualization.

The visualization works by implementing a __repr__ function for each of the conversion steps, and then calling it for every conversion step performed by a given model architecture.

I also implemented a small fix: the Gemma architecture was returning None from the blocks_conversion function, so the weights were not being properly loaded into the different blocks of the Gemma architecture.
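To illustrate the __repr__-based approach described above, here is a minimal, self-contained sketch. It is not the code from this PR: the classes RenameStep, RearrangeStep and ConversionSet, and the weight names used in the example, are hypothetical stand-ins chosen only to show how per-step __repr__ methods can compose into a readable overview of a conversion.

```python
from dataclasses import dataclass
from typing import Dict, Union


@dataclass(repr=False)
class RenameStep:
    """Hypothetical step: copy a source weight under a new name."""
    source: str

    def __repr__(self) -> str:
        return f"copied from '{self.source}'"


@dataclass(repr=False)
class RearrangeStep:
    """Hypothetical step: reshape a source weight using a pattern string."""
    source: str
    pattern: str

    def __repr__(self) -> str:
        return f"rearranged from '{self.source}' with pattern '{self.pattern}'"


@dataclass(repr=False)
class ConversionSet:
    """Hypothetical container mapping target names to steps or nested sets."""
    steps: Dict[str, Union[RenameStep, RearrangeStep, "ConversionSet"]]

    def __repr__(self) -> str:
        lines = []
        for target, step in self.steps.items():
            if isinstance(step, ConversionSet):
                # Nested sets (e.g. per-block weights) are indented by two spaces.
                lines.append(f"{target}:")
                lines.extend("  " + line for line in repr(step).splitlines())
            else:
                lines.append(f"{target} <- {step!r}")
        return "\n".join(lines)


# Illustrative weight names only; real architectures define their own.
conversion = ConversionSet({
    "embed.W_E": RenameStep("model.embed_tokens.weight"),
    "blocks": ConversionSet({
        "attn.W_Q": RearrangeStep("self_attn.q_proj.weight", "(n h) m -> n m h"),
    }),
})

print(conversion)
```

Because each step describes only itself, printing the top-level object walks the whole conversion. In a design like this, a helper that builds the nested block-level set but forgets to return it would hand None to the caller, which is the kind of problem the Gemma fix above addresses.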
I also had to move some imports to exactly where they were needed, because I would otherwise get circular import errors. I'm happy to revert that change if that is a problem that was only occurring for me locally.
There is no specific issue attached to this PR.
Type of change
Screenshots
This is what the output looks like for the function call
Checklist: