Skip to content

Commit

Permalink
version bump v0.0.1.dev8 - renamed dual flatness to flatness components
Browse files Browse the repository at this point in the history
  • Loading branch information
nmichlo committed Mar 19, 2021
1 parent 787dd8f commit 5c30f6d
Show file tree
Hide file tree
Showing 5 changed files with 26 additions and 23 deletions.
11 changes: 7 additions & 4 deletions README.md
Original file line number Diff line number Diff line change
@@ -1,7 +1,6 @@

<p align="center">
<h1 align="center">🧶 Disent</h1>
<p align="center">⚠️ W.I.P</p>
<p align="center">
<i>A modular disentangled representation learning framework for pytorch</i>
</p>
Expand Down Expand Up @@ -29,9 +28,13 @@
</p>

<p align="center">
<p align="center">⚠️ API is not yet stable</p>
<p align="center">
Visit the <a href="https://disent.dontpanic.sh/">docs</a> for more info, or browse the <a href="https://github.com/nmichlo/disent/releases">releases</a>.
</p>
<p align="center">
<a href="https://github.com/nmichlo/disent/issues/new/choose">Contributions</a> are welcome!
</p>
</p>

----------------------
Expand Down Expand Up @@ -133,10 +136,10 @@ submit an issue if you have a request for an additional framework.
+ [SAP](https://arxiv.org/abs/1711.00848)
+ [Unsupervised Scores](https://github.com/google-research/disentanglement_lib)
+ 🧵 Flatness Score
- Measures max width over path length of factor traversal embeddings, a combined measure of linearity and ordering.
+ 🧵 Dual Flatness - Linearity & Ordering
- Measures max width (furthest two points) over path length (sum of distances between consecutive points) of factor traversal embeddings. A combined measure of linearity, monotonicity and ordering.
+ 🧵 Flatness Components - Linearity, Monotonicity & Ordering
- Measure **linearity** of factor traversal embeddings using average Pearson's correlation matrices
- Measure **ordering** of factor traversal embedding using average Spearman's rank correlation matrices
- Measure **monotonicity** of factor traversal embedding using average Spearman's rank correlation matrices
- Measure **ordering** of embeddings by checking anchor-positive and anchor-negative distances correspond to ground-truth factors

Some popular metrics still need to be added, please submit an issue if you wish to
Expand Down
30 changes: 15 additions & 15 deletions disent/metrics/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -31,7 +31,7 @@

# Nathan Michlo et. al
from ._flatness import metric_flatness
from ._dual_flatness import metric_dual_flatness
from ._flatness_components import metric_flatness_components


# ========================================================================= #
Expand All @@ -43,21 +43,21 @@


FAST_METRICS = {
'dci': _wrapped_partial(metric_dci, num_train=1000, num_test=500, boost_mode='sklearn'), # takes
'factor_vae': _wrapped_partial(metric_factor_vae, num_train=700, num_eval=350, num_variance_estimate=1000), # may not be accurate, but it just takes waay too long otherwise 20+ seconds
'flatness': _wrapped_partial(metric_flatness, factor_repeats=128),
'dual_flatness': _wrapped_partial(metric_dual_flatness, factor_repeats=128),
'mig': _wrapped_partial(metric_mig, num_train=2000),
'sap': _wrapped_partial(metric_sap, num_train=2000, num_test=1000),
'unsupervised': _wrapped_partial(metric_unsupervised, num_train=2000),
'dci': _wrapped_partial(metric_dci, num_train=1000, num_test=500, boost_mode='sklearn'), # takes
'factor_vae': _wrapped_partial(metric_factor_vae, num_train=700, num_eval=350, num_variance_estimate=1000), # may not be accurate, but it just takes waay too long otherwise 20+ seconds
'flatness': _wrapped_partial(metric_flatness, factor_repeats=128),
'flatness_components': _wrapped_partial(metric_flatness_components, factor_repeats=128),
'mig': _wrapped_partial(metric_mig, num_train=2000),
'sap': _wrapped_partial(metric_sap, num_train=2000, num_test=1000),
'unsupervised': _wrapped_partial(metric_unsupervised, num_train=2000),
}

DEFAULT_METRICS = {
'dci': metric_dci,
'factor_vae': metric_factor_vae,
'flatness': metric_flatness,
'dual_flatness': metric_dual_flatness,
'mig': metric_mig,
'sap': metric_sap,
'unsupervised': metric_unsupervised,
'dci': metric_dci,
'factor_vae': metric_factor_vae,
'flatness': metric_flatness,
'flatness_components': metric_flatness_components,
'mig': metric_mig,
'sap': metric_sap,
'unsupervised': metric_unsupervised,
}
Original file line number Diff line number Diff line change
Expand Up @@ -50,7 +50,7 @@
# ========================================================================= #


def metric_dual_flatness(
def metric_flatness_components(
ground_truth_dataset: GroundTruthDataset,
representation_function: callable,
factor_repeats: int = 1024,
Expand Down Expand Up @@ -211,7 +211,7 @@ def print_r(name, steps, result, clr=colors.lYLW, t: Timer = None):
def calculate(name, steps, dataset, get_repr):
global aggregate_measure_distances_along_factor
with Timer() as t:
r = metric_dual_flatness(dataset, get_repr, factor_repeats=64, batch_size=64)
r = metric_flatness_components(dataset, get_repr, factor_repeats=64, batch_size=64)
results.append((name, steps, r))
print_r(name, steps, r, colors.lRED, t=t)
print(colors.GRY, '='*100, colors.RST, sep='')
Expand Down
2 changes: 1 addition & 1 deletion experiment/config/metrics/all.yaml
Original file line number Diff line number Diff line change
@@ -1,7 +1,7 @@
# @package _group_
metric_list:
- flatness:
- dual_flatness:
- flatness_components:
- mig:
- sap:
- unsupervised:
Expand Down
2 changes: 1 addition & 1 deletion setup.py
Original file line number Diff line number Diff line change
Expand Up @@ -48,7 +48,7 @@
author="Nathan Juraj Michlo",
author_email="[email protected]",

version="0.0.1.dev7",
version="0.0.1.dev8",
python_requires="==3.8",
packages=setuptools.find_packages(),

Expand Down

0 comments on commit 5c30f6d

Please sign in to comment.