Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

PyTorch model extractor #298

Merged
merged 6 commits into from
Feb 20, 2025
Merged

PyTorch model extractor #298

merged 6 commits into from
Feb 20, 2025

Conversation

mastoffel
Copy link
Collaborator

@mastoffel mastoffel commented Feb 12, 2025

  • adds a function to extract PyTorch models from AutoEmulate emulators

  • emulators can be pipelines, MultiOutputRegressors etc., this function checks all the options and extracts the underlying PyTorch model where possible and throws an error for other models

  • it also gives a message saying that datapreprocessing is better turned off when doing this and has to be done manually (as it can't be attached to the PyTorch model like it can be to a sci-kit learn model using a pipeline

  • it returns the model in eval mode

  • does not yet include other objects as discussed in Add "fit only pytorch models" flag #291 . Maybe we leave that to the next PR?

Copy link
Contributor

Coverage report

Click to see where and how coverage changed

FileStatementsMissingCoverageCoverage
(new stmts)
Lines missing
  autoemulate
  utils.py 426, 435, 449
  autoemulate/emulators
  conditional_neural_process.py
  gaussian_process.py
  gaussian_process_mt.py
  tests
  test_pytorch_utils.py
Project Total  

This report was generated by python-coverage-comment-action

@codecov-commenter
Copy link

Codecov Report

Attention: Patch coverage is 96.47059% with 3 lines in your changes missing coverage. Please review.

Project coverage is 94.22%. Comparing base (7a4dc72) to head (40dbed2).
Report is 5 commits behind head on main.

Files with missing lines Patch % Lines
autoemulate/utils.py 85.71% 3 Missing ⚠️
Additional details and impacted files
@@            Coverage Diff             @@
##             main     #298      +/-   ##
==========================================
+ Coverage   94.20%   94.22%   +0.02%     
==========================================
  Files          62       63       +1     
  Lines        3606     3691      +85     
==========================================
+ Hits         3397     3478      +81     
- Misses        209      213       +4     

☔ View full report in Codecov by Sentry.
📢 Have feedback on the report? Share it here.

@marjanfamili
Copy link
Collaborator

Thanks a lot for this PR Martin. It looks great. Would this functionality be necessary if we change everything to PyTorch? is this a temporary measure until we do ?

@mastoffel
Copy link
Collaborator Author

mastoffel commented Feb 18, 2025

Thanks a lot for this PR Martin. It looks great. Would this functionality be necessary if we change everything to PyTorch? is this a temporary measure until we do ?

@marjanfamili, I think it wouldn't be necessary if everything is changed to PyTorch! So, yes, temporary.

Copy link
Collaborator

@marjanfamili marjanfamili left a comment

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Looks great, thanks

@mastoffel mastoffel merged commit 03d9005 into main Feb 20, 2025
6 checks passed
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment
Labels
None yet
Projects
None yet
Development

Successfully merging this pull request may close these issues.

3 participants