Skip to content

Commit

Permalink
fixed aggregation when 0-dim torch tensors on gpu are passed
Browse files Browse the repository at this point in the history
  • Loading branch information
BulatVakhitov committed Aug 1, 2023
1 parent 538a2ea commit 84e8ab1
Showing 1 changed file with 4 additions and 1 deletion.
5 changes: 4 additions & 1 deletion batchflow/models/torch/base.py
Original file line number Diff line number Diff line change
Expand Up @@ -1551,7 +1551,10 @@ def aggregate_microbatches(self, outputs, chunked_outputs, chunk_sizes, single_o
for chunk_output, chunk_size in zip(chunked_output, chunk_sizes)], dim=0)
result.append(output_)
else:
result.append(np.mean([chunk_output.item() for chunk_output in chunked_output]))
if isinstance(chunked_output[0], np.ndarray):
result.append(np.mean(chunked_output))
else:
result.append(torch.mean(torch.stack(chunked_output)))

if single_output:
result = result[0]
Expand Down

0 comments on commit 84e8ab1

Please sign in to comment.