-
Notifications
You must be signed in to change notification settings - Fork 0
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
[SD-126] Save model summary during benchmarking #53
Merged
Merged
Changes from 3 commits
Commits
Show all changes
30 commits
Select commit
Hold shift + click to select a range
2d1fdc3
WIP PR 126
OCarrollM 6d9c089
SD-126 Console print out of JSON material asked in ticket. Needs JSON…
OCarrollM 129f045
Update loop to use named_children
osw282 2623a30
SD-126: Code cleaned and contents saved to JSON file
OCarrollM 3d26770
SD-126 Code cleaned, added load_model and get_layers to common module…
OCarrollM f7a836d
SD-126: Changed to old datacollection function, removed imports from …
OCarrollM 0827c54
Updated script with dvc added
OCarrollM 2596baf
Try fixing corrupted DVC cache
d-lowl 62320fa
uv lock
d-lowl b1827ee
Fix imports
d-lowl e7b1d38
Add partial data to DVC
d-lowl cebcb44
Specify weights when loading
d-lowl 0779eb1
Remove prebuilt models measurements from this branch
d-lowl 60cb54a
Regenerate date with a different indentation, to avoid the corrupted …
d-lowl c75d7b0
Regenerate date with a different indentation, to avoid the corrupted …
d-lowl c69ba76
Test model loading
osw282 caae6e6
Test
osw282 668505b
Test torch hub endpoitns
osw282 65459ba
Test with version
osw282 313099d
Update working torch model repo version
osw282 e173da3
Add test file
d-lowl c79ea8b
Another test
osw282 b164482
Fix lenet input dimensions
osw282 a33d191
Add model summaries collected on the jetson
osw282 773e00b
Remove test files
d-lowl 0a10fe6
Fixed data collected on jetson
osw282 3d47471
Revert some unnecessary changes
d-lowl 67436b1
Address comments
d-lowl 05554c2
Address comments
d-lowl ddfe693
Merge branch 'develop' into SD-126-JSON-Save
d-lowl File filter
Filter by extension
Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
There are no files selected for viewing
d-lowl marked this conversation as resolved.
Show resolved
Hide resolved
|
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,102 @@ | ||
import torch | ||
import json | ||
from typing import Any | ||
# from torchsummary import summary | ||
from torchinfo import summary | ||
# from model.lenet import LeNet | ||
|
||
|
||
def load_model(model_name: str, model_repo: str) -> Any: | ||
d-lowl marked this conversation as resolved.
Show resolved
Hide resolved
|
||
"""Load model from Pytorch Hub. | ||
|
||
Args: | ||
model_name: Name of model. | ||
It should be same as that in Pytorch Hub. | ||
|
||
Raises: | ||
ValueError: If loading model fails from PyTorch Hub | ||
|
||
Returns: | ||
PyTorch model | ||
""" | ||
# if model_name == "lenet": | ||
# return LeNet() | ||
# if model_name == "fcn_resnet50": | ||
# return torch.hub.load(model_repo, model_name, pretrained=True) | ||
try: | ||
return torch.hub.load(model_repo, model_name, pretrained=True) | ||
except: | ||
raise ValueError( | ||
f"Model name: {model_name} is most likely incorrect. " | ||
"Please refer https://pytorch.org/hub/ to get model name." | ||
) | ||
|
||
|
||
def get_layers(model: torch.nn.Module, name_prefix: str="") -> list[tuple[str, torch.nn.Module]]: | ||
d-lowl marked this conversation as resolved.
Show resolved
Hide resolved
|
||
""" | ||
Recursively get all layers in a pytorch model. | ||
|
||
Args: | ||
model: the pytorch model to look for layers. | ||
name_prefix: Use to identify the parents layer. Defaults to "". | ||
|
||
Returns: | ||
a list of tuple containing the layer name and the layer. | ||
""" | ||
children = list(model.named_children()) | ||
|
||
if len(children) == 0: # No child | ||
result = [(name_prefix, model)] | ||
else: | ||
# If have children, iterate over each child. | ||
result = [] | ||
for child_name, child in children: | ||
# Recursively call get_layers on the child, appending the current | ||
# child's name to the name_prefix. | ||
layers = get_layers(child, name_prefix + "_" + child_name) | ||
result.extend(layers) | ||
|
||
return result | ||
|
||
def get_layer_info(model, input_shape): | ||
model_info = {} | ||
test = torch.randn(*input_shape) | ||
hooks = [] | ||
|
||
def register_hook(layer_name): | ||
def hook(module, input, output): | ||
d-lowl marked this conversation as resolved.
Show resolved
Hide resolved
|
||
model_info[layer_name] = { | ||
"input_shape": tuple(input[0].size()) if input else None, | ||
"output_shape": tuple(output.size()) if output is not None else None, | ||
d-lowl marked this conversation as resolved.
Show resolved
Hide resolved
|
||
"kernel_size": getattr(module, "kernel_size", None), | ||
"stride": getattr(module, "stride", None), | ||
"padding": getattr(module, "padding", None), | ||
"type": module.__class__.__name__, | ||
# "kernal_size": layer.kernel_size if hasattr(layer, "kernel_size") else None, | ||
# "stride": layer.stride if hasattr(layer, "stride") else None, | ||
# "padding": layer.padding if hasattr(layer, "padding") else None, | ||
# "type": type(layer) | ||
} | ||
return hook | ||
|
||
for layer_name, layer in get_layers(model): | ||
hooks.append(layer.register_forward_hook(register_hook(layer_name))) | ||
|
||
model.eval() | ||
with torch.no_grad(): | ||
_ = model(test) | ||
|
||
for hook in hooks: | ||
hook.remove() | ||
|
||
return model_info | ||
|
||
|
||
model = load_model("resnet18", "pytorch/vision:v0.10.0") | ||
d-lowl marked this conversation as resolved.
Show resolved
Hide resolved
|
||
layer_info = get_layer_info(model, (1, 3, 224, 224)) | ||
# print(len(layer_info)) | ||
|
||
print(json.dumps(layer_info, indent=4, separators=(",", ": "), ensure_ascii=False, default=str)) | ||
|
||
|
||
# save to json later |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. uv needs locking dependencies again, after removing torchsummary from the dependencies list |
Some generated files are not rendered by default. Learn more about how customized files appear on GitHub.
Oops, something went wrong.
Add this suggestion to a batch that can be applied as a single commit.
This suggestion is invalid because no changes were made to the code.
Suggestions cannot be applied while the pull request is closed.
Suggestions cannot be applied while viewing a subset of changes.
Only one suggestion per line can be applied in a batch.
Add this suggestion to a batch that can be applied as a single commit.
Applying suggestions on deleted lines is not supported.
You must change the existing code in this line in order to create a valid suggestion.
Outdated suggestions cannot be applied.
This suggestion has been applied or marked resolved.
Suggestions cannot be applied from pending reviews.
Suggestions cannot be applied on multi-line comments.
Suggestions cannot be applied while the pull request is queued to merge.
Suggestion cannot be applied right now. Please check back later.
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Top level comment: I'd split this file into a script under power_logging and the utility functions left here.
Just so that it can be properly called from the top level package (and the instructions for it included in the README)
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
On top of that. We should have a bash script that would run the said script for all the models and save the results to raw_data/prebuilt_models/{model_name}/model_summary.json