
Commit f9702cf

Author: KevinMusgrave

Updated the documentation
1 parent 1affee7 commit f9702cf


2 files changed: +40 -5 lines changed


docs/accuracy_calculation.md

+27 -3
@@ -6,14 +6,14 @@ The AccuracyCalculator class computes several accuracy metrics given a query and
 from pytorch_metric_learning.utils.accuracy_calculator import AccuracyCalculator
 AccuracyCalculator(include=(), exclude=(), avg_of_avgs=False, k=None)
 ```
-**Parameters**:
+### Parameters
 
 * **include**: Optional. A list or tuple of strings, which are the names of metrics you want to calculate. If left empty, all default metrics will be calculated.
 * **exclude**: Optional. A list or tuple of strings, which are the names of metrics you **do not** want to calculate.
 * **avg_of_avgs**: If True, the average accuracy per class is computed, and then the average of those averages is returned. This can be useful if your dataset has unbalanced classes. If False, the global average will be returned.
 * **k**: If set, this number of nearest neighbors will be retrieved for metrics that require k-nearest neighbors. If None, the value of k will be determined automatically.
 
-**Getting accuracy**:
+### Getting accuracy
 
 Call the ```get_accuracy``` method to obtain a dictionary of accuracies.
 ```python
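For instance, a calculator limited to two of the default metrics could be set up like this; a minimal sketch based only on the constructor and parameters above, with illustrative values:

```python
from pytorch_metric_learning.utils.accuracy_calculator import AccuracyCalculator

# Only compute precision_at_1 and r_precision, averaging per class first
# (useful for unbalanced datasets). k=None lets k be determined automatically.
calculator = AccuracyCalculator(
    include=("precision_at_1", "r_precision"),
    exclude=(),
    avg_of_avgs=True,
    k=None,
)
```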
@@ -42,7 +42,31 @@ def get_accuracy(self,
 * **include**: Optional. A list or tuple of strings, which are the names of metrics you want to calculate. If left empty, all metrics specified during initialization will be calculated.
 * **exclude**: Optional. A list or tuple of strings, which are the names of metrics you do not want to calculate.
 
-**Adding custom accuracy metrics**
+### Explanations of the default accuracy metrics
+
+- **AMI**:
+
+    - [scikit-learn article](https://scikit-learn.org/stable/modules/generated/sklearn.metrics.adjusted_mutual_info_score.html)
+    - [Wikipedia](https://en.wikipedia.org/wiki/Adjusted_mutual_information)
+
+- **NMI**:
+
+    - [scikit-learn article](https://scikit-learn.org/stable/modules/generated/sklearn.metrics.normalized_mutual_info_score.html)
+    - [Slides from Northeastern University](https://course.ccs.neu.edu/cs6140sp15/7_locality_cluster/Assignment-6/NMI.pdf)
+
+- **mean_average_precision_at_r**:
+
+    - [See section 3.2 of A Metric Learning Reality Check](https://arxiv.org/pdf/2003.08505.pdf)
+
+- **precision_at_1**:
+
+    - Fancy way of saying "is the 1st nearest neighbor correct?"
+
+- **r_precision**:
+
+    - [See section 3.2 of A Metric Learning Reality Check](https://arxiv.org/pdf/2003.08505.pdf)
+
+### Adding custom accuracy metrics
 
 Let's say you want to use the existing metrics but also compute precision @ 2, and a fancy mutual info method. You can extend the existing class, and write methods that start with the keyword ```calculate_```

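A sketch of that extension pattern is below. The keyword argument names (```knn_labels```, ```query_labels```) are assumptions about what the base class passes to each ```calculate_``` method, and the behavior with an empty ```include``` is also assumed rather than stated here:

```python
import numpy as np
from pytorch_metric_learning.utils.accuracy_calculator import AccuracyCalculator

class CustomCalculator(AccuracyCalculator):
    def calculate_precision_at_2(self, knn_labels, query_labels, **kwargs):
        # Assumed inputs: knn_labels has shape (num_queries, k) and holds the labels
        # of each query's k nearest neighbors; query_labels has shape (num_queries,).
        matches = knn_labels[:, :2] == query_labels[:, None]
        # Fraction of the 2 nearest neighbors that share the query's label,
        # averaged over all queries.
        return np.mean(matches.astype(float))

# With include left empty, the default metrics plus the new "precision_at_2"
# metric (named after the method suffix) are assumed to be computed.
calculator = CustomCalculator()
```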
docs/testers.md

+13 -2
@@ -11,6 +11,18 @@ tester.test(dataset_dict, epoch, model)
 # Or if your model is composed of a trunk + embedder
 tester.test(dataset_dict, epoch, trunk, embedder)
 ```
+You can perform custom actions by writing an end-of-testing hook (see the documentation for [BaseTester](#basetester)), and you can access the test results directly via the ```all_accuracies``` attribute:
+```python
+print(tester.all_accuracies)
+```
+This will print out a dictionary of accuracy metrics, per dataset split. You'll see something like this:
+```python
+{"train": {"AMI_level0": 0.53, ...}, "val": {"AMI_level0": 0.44, ...}}
+```
+Each of the accuracy metric names is appended with ```level0```, which refers to the 0th label hierarchy level (see the documentation for [BaseTester](#basetester)). This is only relevant if you're dealing with multi-label datasets.
+
+For an explanation of the default accuracy metrics, see the [AccuracyCalculator documentation](accuracy_calculation.md#explanations-of-the-default-accuracy-metrics).
+
 
 ## BaseTester
 All testers extend this class and therefore inherit its ```__init__``` arguments.
@@ -48,7 +60,6 @@ testers.BaseTester(reference_set="compared_to_self",
 * **label_hierarchy_level**: If each sample in your dataset has multiple labels, then this integer argument can be used to select which "level" to use. This assumes that your labels are "2-dimensional" with shape (num_samples, num_hierarchy_levels). Leave this at the default value, 0, if your data does not have multiple labels per sample.
 * **end_of_testing_hook**: This is an optional function that has one input argument (the tester object) and performs some action (e.g. logging data) at the end of testing.
     * You'll probably want to access the accuracy metrics, which are stored in ```tester.all_accuracies```. This is a nested dictionary with the following format: ```tester.all_accuracies[split_name][metric_name] = metric_value```
-    * If you set ```size_of_tsne``` to be greater than 0, then the T-SNE embeddings will be stored in ```tester.tsne_embeddings``` which is a dictionary with the following format: ```tester.tsne_embeddings[split_name]["tsne_level%d"] = (embeddings, labels)```. (Note that ```"tsne_level%d"``` refers to the label hierarchy level. If you use the default label hierarchy level, then the string will be ```"tsne_level0"```.)
     * If you want ready-to-use hooks, take a look at the [logging_presets module](logging_presets.md).
 * **dataset_labels**: The labels for your dataset. Can be 1-dimensional (1 label per datapoint) or 2-dimensional, where each row represents a datapoint, and the columns are the multiple labels that the datapoint has. Labels can be integers or strings. **This option needs to be specified only if ```set_min_label_to_zero``` is True.**
 * **set_min_label_to_zero**: If True, labels will be mapped such that they represent their rank in the label set. For example, if your dataset has labels 5, 10, 12, 13, then at each iteration, these would become 0, 1, 2, 3. You should also set this to True if you want to use string labels. In that case, 'dog', 'cat', 'monkey' would get mapped to 1, 0, 2. If True, you must pass in ```dataset_labels``` (see above). The default is False.
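A minimal sketch of an ```end_of_testing_hook```, based only on the hook and ```all_accuracies``` format described above; the hook body and the choice of ```GlobalEmbeddingSpaceTester``` are just examples:

```python
from pytorch_metric_learning import testers

def log_accuracies(tester):
    # tester.all_accuracies[split_name][metric_name] = metric_value (format documented above)
    for split_name, metrics in tester.all_accuracies.items():
        for metric_name, metric_value in metrics.items():
            print(f"{split_name} / {metric_name}: {metric_value}")

tester = testers.GlobalEmbeddingSpaceTester(end_of_testing_hook=log_accuracies)
```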
@@ -64,7 +75,7 @@ testers.BaseTester(reference_set="compared_to_self",
 
 
 ## GlobalEmbeddingSpaceTester
-Computes nearest neighbors by looking at all points in the embedding space. This is probably the tester you are looking for. To see it in action, check one of the [example notebooks](https://github.com/KevinMusgrave/pytorch-metric-learning/tree/master/examples)
+Computes nearest neighbors by looking at all points in the embedding space (rather than a subset). This is probably the tester you are looking for. To see it in action, check one of the [example notebooks](https://github.com/KevinMusgrave/pytorch-metric-learning/tree/master/examples)
 ```python
 testers.GlobalEmbeddingSpaceTester(*args, **kwargs)
 ```
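A sketch of the full flow with this tester, assuming ```dataset_dict```, ```epoch```, and ```model``` come from your own training setup, as in the ```tester.test``` snippet at the top of this file:

```python
from pytorch_metric_learning import testers

tester = testers.GlobalEmbeddingSpaceTester()
tester.test(dataset_dict, epoch, model)  # dataset_dict, epoch, model: assumed from your setup

# Nested format: tester.all_accuracies[split_name][metric_name]
print(tester.all_accuracies["val"]["AMI_level0"])
```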
