Skip to content

Commit

Permalink
[Fix] Fix nested predict for multi-task prediction. (open-mmlab#1716)
Browse files Browse the repository at this point in the history
* fix: multi task predict

* change the loop

---------

Co-authored-by: Pierre Colle <[email protected]>
  • Loading branch information
marouaneamz and piercus authored Jul 28, 2023
1 parent 64c446d commit e7fc25c
Show file tree
Hide file tree
Showing 2 changed files with 40 additions and 2 deletions.
14 changes: 13 additions & 1 deletion mmpretrain/models/heads/multi_task_head.py
Original file line number Diff line number Diff line change
Expand Up @@ -119,12 +119,24 @@ def predict(
predictions_dict = dict()

for task_name, head in self.task_heads.items():
task_samples = head.predict(feats)
task_samples = None
if data_samples is not None:
task_samples = [
data_sample.get(task_name, None) if data_sample else None
for data_sample in data_samples
]

task_samples = head.predict(feats, task_samples)
batch_size = len(task_samples)
predictions_dict[task_name] = task_samples

if data_samples is None:
data_samples = [MultiTaskDataSample() for _ in range(batch_size)]
else:
data_samples = [
MultiTaskDataSample() if data_sample is None else data_sample
for data_sample in data_samples
]

for task_name, task_samples in predictions_dict.items():
for data_sample, task_sample in zip(data_samples, task_samples):
Expand Down
28 changes: 27 additions & 1 deletion tests/test_models/test_heads.py
Original file line number Diff line number Diff line change
Expand Up @@ -549,7 +549,7 @@ def test_predict(self):
data_sample.set_field(task_sample, task_name)
data_samples.append(data_sample)
head = MODELS.build(self.DEFAULT_ARGS)
# with without data_samples
# without data_samples
predictions = head.predict(feats)
self.assertTrue(is_seq_of(predictions, MultiTaskDataSample))
for pred in predictions:
Expand All @@ -564,6 +564,32 @@ def test_predict(self):
self.assertIs(sample, pred)
self.assertIn('task0', pred)

# with data samples and nested
head_nested = MODELS.build(self.DEFAULT_ARGS2)
# adding a None data sample at the beginning
data_samples_nested = [None]
for _ in range(3):
data_sample_nested = MultiTaskDataSample()
data_sample_nested0 = MultiTaskDataSample()
data_sample_nested0.set_field(DataSample().set_gt_label(1),
'task00')
data_sample_nested0.set_field(DataSample().set_gt_label(1),
'task01')
data_sample_nested.set_field(data_sample_nested0, 'task0')
data_sample_nested.set_field(DataSample().set_gt_label(1), 'task1')
data_samples_nested.append(data_sample_nested)

predictions = head_nested.predict(feats, data_samples_nested)
self.assertTrue(is_seq_of(predictions, MultiTaskDataSample))
for i in range(3):
sample = data_samples_nested[i + 1]
pred = predictions[i + 1]
self.assertIn('task0', pred)
self.assertIn('task1', pred)
self.assertIn('task01', pred.get('task0'))
self.assertEqual(
sample.get('task0').get('task01').gt_label.numpy()[0], 1)

def test_loss_empty_data_sample(self):
feats = (torch.rand(4, 10), )
data_samples = []
Expand Down

0 comments on commit e7fc25c

Please sign in to comment.