-
Notifications
You must be signed in to change notification settings - Fork 698
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
Allow passing metrics objects directly to create_metrics_collection
#2212
base: main
Are you sure you want to change the base?
Allow passing metrics objects directly to create_metrics_collection
#2212
Conversation
Signed-off-by: Ashwin Vaidya <[email protected]>
Signed-off-by: Ashwin Vaidya <[email protected]>
…n/refactor_engine_metrics
Codecov ReportAttention: Patch coverage is
Additional details and impacted files@@ Coverage Diff @@
## main #2212 +/- ##
==========================================
- Coverage 80.80% 80.79% -0.01%
==========================================
Files 248 248
Lines 10859 10864 +5
==========================================
+ Hits 8775 8778 +3
- Misses 2084 2086 +2 ☔ View full report in Codecov by Sentry. |
if not ( | ||
all(isinstance(metric, str) for metric in metrics) or all(isinstance(metric, Metric) for metric in metrics) | ||
): | ||
msg = f"All metrics must be either string or Metric objects, found {metrics}" |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Would this mean that a user cannot pass the following:
from torchmetrics.classification import Accuracy, Precision, Recall
from anomalib.data import MVTec
from anomalib.engine import Engine
from anomalib.models import Padim
if __name__ == "__main__":
model = Padim()
data = MVTec()
engine = Engine(image_metrics=["F1Score", Accuracy(task="binary")])
engine.train(model, datamodule=data)
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Also, would it be an idea to have an additional check like;
from torchmetrics.classification import Accuracy, Precision, Recall
import types
def instantiate_if_needed(metric, task="binary"):
if isinstance(metric, types.FunctionType) or isinstance(metric, type):
# If metric is a function or a class (not instantiated)
return metric(task=task)
else:
# If metric is already instantiated
return metric
Or do you think if this is overkill?
📝 Description
To test
✨ Changes
Select what type of change your PR is:
✅ Checklist
Before you submit your pull request, please make sure you have completed the following steps:
For more information about code review checklists, see the Code Review Checklist.