* **include**: Optional. A list or tuple of strings, which are the names of metrics you want to calculate. If left empty, all default metrics will be calculated.
* **exclude**: Optional. A list or tuple of strings, which are the names of metrics you **do not** want to calculate.
* **avg_of_avgs**: If True, the average accuracy per class is computed, and then the average of those averages is returned. This can be useful if your dataset has unbalanced classes. If False, the global average will be returned.
* **k**: If set, this number of nearest neighbors will be retrieved for metrics that require k-nearest neighbors. If None, the value of k will be determined automatically.
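For example, a minimal sketch of constructing a calculator with these options (the metric names used here come from the default set described below, and the particular values chosen are only illustrative):

```python
from pytorch_metric_learning.utils.accuracy_calculator import AccuracyCalculator

# Compute only two of the default metrics, average per class,
# and retrieve 10 nearest neighbors for the k-nn based metrics.
calculator = AccuracyCalculator(
    include=("precision_at_1", "r_precision"),
    exclude=(),
    avg_of_avgs=True,
    k=10,
)
```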
### Getting accuracy
Call the ```get_accuracy``` method to obtain a dictionary of accuracies.
```python
def get_accuracy(self,
                 query,
                 reference,
                 query_labels,
                 reference_labels,
                 embeddings_come_from_same_source,
                 include=(),
                 exclude=()):
```
* **include**: Optional. A list or tuple of strings, which are the names of metrics you want to calculate. If left empty, all metrics specified during initialization will be calculated.
* **exclude**: Optional. A list or tuple of strings, which are the names of metrics you do not want to calculate.
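A sketch of a call, matching the signature above. Here the query set is compared against itself, so ```embeddings_come_from_same_source``` is set to True; the arrays are random and only illustrate the expected shapes (whether numpy arrays or torch tensors are expected may depend on your version of the library):

```python
import numpy as np

embeddings = np.random.randn(100, 128)        # 100 embeddings of size 128
labels = np.random.randint(0, 10, size=100)   # one label per embedding

accuracies = calculator.get_accuracy(
    embeddings, embeddings,   # query and reference are the same set
    labels, labels,
    embeddings_come_from_same_source=True,
)
print(accuracies)
```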
### Explanations of the default accuracy metrics

* **NMI**:
    * [Slides from Northeastern University](https://course.ccs.neu.edu/cs6140sp15/7_locality_cluster/Assignment-6/NMI.pdf)
* **mean_average_precision_at_r**:
    * [See section 3.2 of A Metric Learning Reality Check](https://arxiv.org/pdf/2003.08505.pdf)
* **precision_at_1**:
    * Fancy way of saying "is the 1st nearest neighbor correct?"
* **r_precision**:
    * [See section 3.2 of A Metric Learning Reality Check](https://arxiv.org/pdf/2003.08505.pdf)
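To make the nearest-neighbor metrics concrete, here is a small pedagogical sketch (not the library's implementation) of how ```precision_at_1``` and ```r_precision``` could be computed for a single query:

```python
import torch

# 1 query and 5 reference points, 8-dimensional embeddings.
query = torch.randn(1, 8)
reference = torch.randn(5, 8)
query_label = torch.tensor([0])
reference_labels = torch.tensor([0, 1, 0, 0, 2])

# R = number of reference points with the same label as the query.
R = (reference_labels == query_label).sum().item()  # 3 in this toy example

# Rank the reference points by distance to the query.
dists = torch.cdist(query, reference)[0]
ranking = torch.argsort(dists)

# precision_at_1: is the single nearest neighbor the correct class?
precision_at_1 = float(reference_labels[ranking[0]] == query_label)

# r_precision: of the R nearest neighbors, what fraction have the correct class?
r_precision = (reference_labels[ranking[:R]] == query_label).float().mean().item()
```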
### Adding custom accuracy metrics
Let's say you want to use the existing metrics but also compute precision @ 2, and a fancy mutual info method. You can extend the existing class and write methods whose names start with the prefix ```calculate_```.
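As a sketch of what that might look like (the ```knn_labels```/```query_labels```/```cluster_labels``` keyword arguments, the ```precision_at_k``` helper, and the ```requires_clustering```/```requires_knn``` overrides are assumptions about how the base class dispatches to ```calculate_``` methods; check the source for the exact hooks):

```python
from sklearn.metrics import adjusted_mutual_info_score
from pytorch_metric_learning.utils import accuracy_calculator

class YourCalculator(accuracy_calculator.AccuracyCalculator):
    # Any method named calculate_<metric> becomes a metric called <metric>.
    def calculate_precision_at_2(self, knn_labels, query_labels, **kwargs):
        # Fraction of queries whose 2 nearest neighbors have the correct label.
        return accuracy_calculator.precision_at_k(knn_labels, query_labels[:, None], 2)

    def calculate_fancy_mutual_info(self, query_labels, cluster_labels, **kwargs):
        # A "fancy" mutual-info score between cluster assignments and true labels.
        return adjusted_mutual_info_score(query_labels, cluster_labels)

    # Tell the base class which metrics need clustering / k-nn results.
    def requires_clustering(self):
        return super().requires_clustering() + ["fancy_mutual_info"]

    def requires_knn(self):
        return super().requires_knn() + ["precision_at_2"]
```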
To run a tester, pass a dictionary of datasets (keyed by split name) to its ```test``` method:

```python
tester.test(dataset_dict, epoch, model)
# Or if your model is composed of a trunk + embedder
tester.test(dataset_dict, epoch, trunk, embedder)
```
You can perform custom actions by writing an end-of-testing hook (see the documentation for [BaseTester](#basetester)), and you can access the test results directly via the ```all_accuracies``` attribute:
```python
print(tester.all_accuracies)
```
This will print out a dictionary of accuracy metrics, per dataset split. You'll see something like this:
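The metric names and numbers below are only illustrative; the actual keys depend on which metrics the accuracy calculator computed:

```python
{
    "train": {"precision_at_1_level0": 0.987, "r_precision_level0": 0.953},
    "val": {"precision_at_1_level0": 0.742, "r_precision_level0": 0.681},
}
```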
Each of the accuracy metric names is appended with ```level0```, which refers to the 0th label hierarchy level (see the documentation for [BaseTester](#basetester)). This is only relevant if you're dealing with multi-label datasets.
For an explanation of the default accuracy metrics, see the [AccuracyCalculator documentation](accuracy_calculation.md#explanations-of-the-default-accuracy-metrics).
## BaseTester
All testers extend this class and therefore inherit its ```__init__``` arguments.
* **label_hierarchy_level**: If each sample in your dataset has multiple labels, then this integer argument can be used to select which "level" to use. This assumes that your labels are "2-dimensional" with shape (num_samples, num_hierarchy_levels). Leave this at the default value, 0, if your data does not have multiple labels per sample.
* **end_of_testing_hook**: This is an optional function that has one input argument (the tester object) and performs some action (e.g. logging data) at the end of testing. See the sketch after this list for an example.
    * You'll probably want to access the accuracy metrics, which are stored in ```tester.all_accuracies```. This is a nested dictionary with the following format: ```tester.all_accuracies[split_name][metric_name] = metric_value```
    * If you set ```size_of_tsne``` to be greater than 0, then the T-SNE embeddings will be stored in ```tester.tsne_embeddings```, which is a dictionary with the following format: ```tester.tsne_embeddings[split_name]["tsne_level%d"] = (embeddings, labels)```. (Note that ```"tsne_level%d"``` refers to the label hierarchy level. If you use the default label hierarchy level, then the string will be ```"tsne_level0"```.)
    * If you want ready-to-use hooks, take a look at the [logging_presets module](logging_presets.md).
* **dataset_labels**: The labels for your dataset. Can be 1-dimensional (1 label per datapoint) or 2-dimensional, where each row represents a datapoint, and the columns are the multiple labels that the datapoint has. Labels can be integers or strings. **This option needs to be specified only if ```set_min_label_to_zero``` is True.**
* **set_min_label_to_zero**: If True, labels will be mapped such that they represent their rank in the label set. For example, if your dataset has labels 5, 10, 12, 13, then at each iteration, these would become 0, 1, 2, 3. You should also set this to True if you want to use string labels. In that case, 'dog', 'cat', 'monkey' would get mapped to 1, 0, 2. If True, you must pass in ```dataset_labels``` (see above). The default is False.
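A sketch that ties a few of these options together. It assumes ```GlobalEmbeddingSpaceTester``` (described below) as the concrete tester; the hook body and the labels are only illustrative:

```python
from pytorch_metric_learning import testers

def end_of_testing_hook(tester):
    # The tester object is passed in, so the accuracy metrics are available here.
    for split_name, metrics in tester.all_accuracies.items():
        print(split_name, metrics)

tester = testers.GlobalEmbeddingSpaceTester(
    end_of_testing_hook=end_of_testing_hook,
    dataset_labels=["dog", "cat", "monkey", "dog"],  # one label per datapoint
    set_min_label_to_zero=True,  # "cat" -> 0, "dog" -> 1, "monkey" -> 2
)
```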
## GlobalEmbeddingSpaceTester
Computes nearest neighbors by looking at all points in the embedding space (rather than a subset). This is probably the tester you are looking for. To see it in action, check one of the [example notebooks](https://github.com/KevinMusgrave/pytorch-metric-learning/tree/master/examples)
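Putting the pieces together, a minimal flow might look like the following; ```train_dataset```, ```val_dataset```, and ```model``` are placeholders for your own datasets and model:

```python
from pytorch_metric_learning import testers

tester = testers.GlobalEmbeddingSpaceTester()
dataset_dict = {"train": train_dataset, "val": val_dataset}
tester.test(dataset_dict, 1, model)  # 1 is the epoch number
print(tester.all_accuracies["val"])
```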