Allow macroaverages with nonequal number of subreports. (facebookresearch#2542)

* Allow macroaverages with nonequal number of subreports.

* Switch to dict powering MacroAverageMetric.

* Fix some type errors.
stephenroller authored Apr 9, 2020
1 parent 4d2899a commit c4cca92
Showing 4 changed files with 39 additions and 18 deletions.
22 changes: 10 additions & 12 deletions parlai/core/metrics.py
@@ -259,23 +259,21 @@ class MacroAverageMetric(Metric):
     AverageMetrics already.
     """

-    __slots__ = ('_values',)
+    __slots__ = '_values'

-    def __init__(self, metrics: List[Metric]) -> None:
+    def __init__(self, metrics: Dict[str, Metric]) -> None:
         self._values = metrics

     def __add__(self, other: Optional['MacroAverageMetric']) -> 'MacroAverageMetric':
         if other is None:
             return self
-        if len(self._values) != len(other._values):
-            raise AssertionError(
-                "MacroAverage keeping track of an uneven number of submetrics. "
-                "There should be exactly one per task."
-            )
-        return MacroAverageMetric([a + b for a, b in zip(self._values, other._values)])
+        output = dict(**self._values)
+        for k, v in other._values.items():
+            output[k] = output.get(k, None) + v
+        return MacroAverageMetric(output)

     def value(self) -> float:
-        sum_ = sum(v.value() for v in self._values)
+        sum_ = sum(v.value() for v in self._values.values())
         n = len(self._values)
         return sum_ / n
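With the dict-backed __add__, the two operands no longer need to cover the same tasks: keys present on only one side are carried over as-is, since Metric addition treats a None operand as the identity. A minimal sketch of the new behaviour (assuming a ParlAI checkout that includes this commit; task names and values are illustrative only):

from parlai.core.metrics import AverageMetric, MacroAverageMetric

# Two partial macro averages covering different task sets:
# task 'a' appears in both, task 'c' only on the right-hand side.
left = MacroAverageMetric({'a': AverageMetric(1, 1)})
right = MacroAverageMetric({'a': AverageMetric(0, 1), 'c': AverageMetric(0, 1)})

combined = left + right
# 'a' accumulates to 1/2, 'c' stays at 0/1, so the macro average is (0.5 + 0.0) / 2
print(combined.value())  # 0.25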

@@ -510,7 +508,7 @@ def aggregate_named_reports(

     # reporters is a list of teachers or worlds
     m: Dict[str, Metric] = {}
-    macro_averages: Dict[str, List[Metric]] = {}
+    macro_averages: Dict[str, Dict[str, Metric]] = {}
     for task_id, task_report in named_reports.items():
         for each_metric, value in task_report.items():
             if value.is_global:
@@ -526,8 +524,8 @@
                 else:
                     # macro average
                     if each_metric not in macro_averages:
-                        macro_averages[each_metric] = []
-                    macro_averages[each_metric].append(value)
+                        macro_averages[each_metric] = {}
+                    macro_averages[each_metric][task_id] = value
     for key, values in macro_averages.items():
         m[key] = MacroAverageMetric(values)
     return m
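Keying macro_averages by task_id rather than appending to a per-metric list is what removes the old requirement of exactly one submetric per task. A rough usage sketch of this aggregation path (again assuming a ParlAI install with this change; task names and counts are made up for illustration):

from parlai.core.metrics import AverageMetric, aggregate_named_reports

reports = {
    'taskA': {'avg': AverageMetric(3, 4)},  # per-task mean 0.75
    'taskB': {'avg': AverageMetric(1, 2)},  # per-task mean 0.5
}
agg = aggregate_named_reports(reports, micro_average=False)

assert agg['taskA/avg'] == 0.75
assert agg['taskB/avg'] == 0.5
# the macro average is the mean of the per-task means: (0.75 + 0.5) / 2
assert agg['avg'] == 0.625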
6 changes: 3 additions & 3 deletions parlai/core/worlds.py
@@ -56,7 +56,7 @@
 from parlai.core.loader import load_task_module, load_world_module
 from parlai.core.metrics import aggregate_named_reports
 from parlai.core.opt import Opt
-from parlai.core.teachers import create_task_agent_from_taskname
+from parlai.core.teachers import Teacher, create_task_agent_from_taskname
 from parlai.utils.misc import Timer, display_messages
 from parlai.tasks.tasks import ids_to_tasks

@@ -612,7 +612,7 @@ def __init__(self, opt: Opt, agents=None, shared=None, default_world=None):
                 weight = 1
             self.cum_task_weights[i] = weight + sum
             sum += weight
-        task_ids = {}
+        task_ids: Dict[str, Teacher] = {}
         # Having overlap in teacher ids will cause issues for metrics aggregation.
         for each_world in self.worlds:
             world_id = each_world.getID()
@@ … @@
                     )
                 )
             else:
-                task_ids[world_id] = each_world.get_agents()[0]
+                task_ids[world_id] = each_world.get_task_agent()

     def num_examples(self):
         """
4 changes: 2 additions & 2 deletions parlai/utils/torch.py
@@ -341,10 +341,10 @@ def make_parallel(self, model: torch.nn.Module) -> torch.nn.Module:
         self.__device_allocations['cuda:0'] += trainable_parameters(model) * 3

         model.apply(self._place_modulelist)
-        model._apply(self._move_rest_to_cuda0)
+        model._apply(self._move_rest_to_cuda0)  # type: ignore
         return model

-    def _move_rest_to_cuda0(self, parameter):
+    def _move_rest_to_cuda0(self, parameter: torch.Tensor):
         if parameter.device.type == 'cpu':
             return parameter.to('cuda:0')
         else:
25 changes: 24 additions & 1 deletion tests/test_metrics.py
@@ -108,7 +108,7 @@ def test_macroaverage_additions(self):
         m2 = AverageMetric(3, 4)

         assert (m1 + m2) == AverageMetric(4, 7)
-        assert MacroAverageMetric([m1, m2]) == 0.5 * (1.0 / 3 + 3.0 / 4)
+        assert MacroAverageMetric({'a': m1, 'b': m2}) == 0.5 * (1.0 / 3 + 3.0 / 4)


 class TestMetrics(unittest.TestCase):
@@ -237,6 +237,29 @@ def test_macro_aggregation(self):
         assert agg['b/fixed'] == 4
         assert 'b/global_avg' not in agg

+    def test_uneven_macro_aggrevation(self):
+        report1 = {
+            'avg': AverageMetric(1, 1),
+        }
+        report2 = {
+            'avg': AverageMetric(0, 1),
+        }
+        report3 = {
+            'avg': AverageMetric(0, 1),
+        }
+        agg1 = aggregate_named_reports(
+            {'a': report1, 'b': report2}, micro_average=False
+        )
+        agg2 = aggregate_named_reports({'a': {}, 'c': report3}, micro_average=False)
+
+        agg = aggregate_unnamed_reports([agg1, agg2])
+        assert agg1['avg'] == 0.5
+        assert agg2['avg'] == 0.0
+        assert agg['a/avg'] == 1.0
+        assert agg['b/avg'] == 0.0
+        assert agg['c/avg'] == 0.0
+        assert agg['avg'] == 1.0 / 3
+
     def test_micro_aggregation(self):
         report1 = {
             'avg': AverageMetric(3, 4),
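For reference, the expected values in the new uneven-aggregation test work out as follows: agg1 macro-averages tasks a and b, (1.0 + 0.0) / 2 = 0.5; agg2 sees an empty report for a, so only task c contributes and its macro average is 0.0; after aggregate_unnamed_reports merges the two, the macro average spans the union of tasks a, b and c, giving (1.0 + 0.0 + 0.0) / 3, exactly the uneven-subreport case that the old list-based implementation rejected with an AssertionError.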
