Skip to content

Commit

Permalink
remove cache system in metrics
Browse files Browse the repository at this point in the history
Signed-off-by: Zhiyuan Chen <[email protected]>
  • Loading branch information
ZhiyuanChen committed Oct 13, 2024
1 parent e1d38ef commit 18a401a
Show file tree
Hide file tree
Showing 3 changed files with 100 additions and 176 deletions.
1 change: 1 addition & 0 deletions danling/metrics/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -27,6 +27,7 @@
from .functional import accuracy, auprc, auroc, f1_score, mcc, pearson, r2_score, rmse, spearman
from .metrics import Metrics, MultiTaskMetrics


__all__ = [
"Metrics",
"MultiTaskMetrics",
Expand Down
183 changes: 53 additions & 130 deletions danling/metrics/metrics.py
Original file line number Diff line number Diff line change
Expand Up @@ -57,7 +57,6 @@ class Metrics(Metric):
metrics: A dictionary of metrics to be computed.
ignored_index: Index to be ignored in the computation.
val: Metric results of current batch on current device.
bat: Metric results of current batch on all devices.
avg: Metric results of all results on all devices.
input: The input tensor of latest batch.
target: The target tensor of latest batch.
Expand Down Expand Up @@ -87,11 +86,6 @@ class Metrics(Metric):
('auroc'): 0.75
('auprc'): 0.8333333730697632
)
>>> metrics.bat # Metrics of current batch on all devices
NestedDict(
('auroc'): 0.75
('auprc'): 0.8333333730697632
)
>>> metrics.avg # Metrics of all data on all devices
NestedDict(
('auroc'): 0.75
Expand All @@ -111,11 +105,6 @@ class Metrics(Metric):
('auroc'): 0.6666666666666666
('auprc'): 0.5
)
>>> metrics.bat # Metrics of current batch on all devices
NestedDict(
('auroc'): 0.6666666666666666
('auprc'): 0.5
)
>>> metrics.avg # Metrics of all data on all devices
NestedDict(
('auroc'): 0.6666666666666666
Expand All @@ -134,14 +123,13 @@ class Metrics(Metric):
ignored_index: Optional[int] = None
_input: Tensor
_target: Tensor
_inputs: flist
_targets: flist
_input_buffer: flist
_target_buffer: flist
_inputs: Tensor
_targets: Tensor
score_name: str
best_fn: Callable
merge_dict: bool = True
return_nested: bool = False
flatten: bool = False

def __init__(
self,
Expand All @@ -156,10 +144,9 @@ def __init__(
super().__init__(device=device)
self._add_state("_input", torch.empty(0))
self._add_state("_target", torch.empty(0))
self._add_state("_inputs", flist())
self._add_state("_targets", flist())
self._add_state("_input_buffer", flist())
self._add_state("_target_buffer", flist())
self._add_state("_inputs", torch.empty(0))
self._add_state("_targets", torch.empty(0))
self.world_size = get_world_size()
self.metrics = FlatDict(*args, **metrics)
self.preprocess = preprocess
if merge_dict is not None:
Expand All @@ -168,7 +155,6 @@ def __init__(
self.return_nested = return_nested
self.ignored_index = ignored_index

@torch.inference_mode()
def update(self, input: Tensor | NestedTensor | Sequence, target: Tensor | NestedTensor | Sequence) -> None:
# convert input and target to Tensor if they are not
if not isinstance(input, (Tensor, NestedTensor)):
Expand All @@ -185,10 +171,20 @@ def update(self, input: Tensor | NestedTensor | Sequence, target: Tensor | Neste
input = input.squeeze(-1)
# convert input and target to NestedTensor if one of them is
if isinstance(input, NestedTensor) or isinstance(target, NestedTensor):
if isinstance(input, NestedTensor) and isinstance(target, Tensor):
target = input.nested_like(target, strict=False)
if isinstance(target, NestedTensor) and isinstance(input, Tensor):
input = target.nested_like(input, strict=False)
if isinstance(target, NestedTensor) and isinstance(input, NestedTensor):
input, target = input.concat, target.concat
elif isinstance(input, NestedTensor):
input, mask = input.concat, input.mask
target = target[mask]
elif isinstance(target, NestedTensor):
target, mask = target.concat, target.mask
input = input[mask]
else:
raise ValueError(f"Unknown input and target: {input}, {target}")
self.flatten = True
elif self.flatten:
target = target.flatten()
input = input.view(*target.shape, -1)
# remove ignored index
if self.ignored_index is not None:
if isinstance(input, NestedTensor):
Expand All @@ -197,40 +193,27 @@ def update(self, input: Tensor | NestedTensor | Sequence, target: Tensor | Neste
target = NestedTensor([t[i] for t, i in zip(target.storage(), indices)])
else:
input, target = input[target != self.ignored_index], target[target != self.ignored_index]
# update internal state
if isinstance(input, NestedTensor):
self._input = input
self._input_buffer.extend(input.detach().cpu().storage()) # type: ignore[union-attr]
self._target = target
self._target_buffer.extend(target.detach().cpu().storage()) # type: ignore[union-attr]
else:
self._input = input
self._input_buffer.append(input.detach().cpu()) # type: ignore[union-attr]
self._target = target
self._target_buffer.append(target.detach().cpu()) # type: ignore[union-attr]

def compute(self) -> NestedDict[str, float | flist]:
return self.calculate(self.inputs.to(self.device), self.targets.to(self.device))
if self.world_size > 1:
input, target = self._sync(input), self._sync(target)
input, target = input.detach().to(self.device), target.detach().to(self.device)
self._input = input
self._target = target
self._inputs = torch.cat([self._inputs, input]).to(input.dtype)
self._targets = torch.cat([self._targets, target]).to(target.dtype)

def value(self) -> NestedDict[str, float | flist]:
    r"""
    Calculate all metrics on the most recently stored batch.

    Reads `self._input` and `self._target`; if either is a `NestedTensor`,
    its `.concat` form is used so that `calculate` receives plain tensors.

    Returns:
        Whatever `self.calculate(input, target)` returns for the latest batch.
    """
    input = self._input.concat if isinstance(self._input, NestedTensor) else self._input
    target = self._target.concat if isinstance(self._target, NestedTensor) else self._target
    return self.calculate(input, target)

def batch(self) -> NestedDict[str, float | flist]:
return self.calculate(self.input, self.target)

def average(self) -> NestedDict[str, float | flist]:
return self.calculate(self.inputs.to(self.device), self.targets.to(self.device))
return self.calculate(self.inputs, self.targets)

def compute(self) -> NestedDict[str, float | flist]:
    r"""Return the result of `average()`; this method only delegates to it."""
    return self.average()

@property
def val(self) -> NestedDict[str, float | flist]:
    r"""Property shortcut for `value()`."""
    return self.value()

@property
def bat(self) -> NestedDict[str, float | flist]:
return self.batch()

@property
def avg(self) -> NestedDict[str, float | flist]:
    r"""Property shortcut for `average()`."""
    return self.average()
Expand Down Expand Up @@ -271,98 +254,35 @@ def _calculate(self, metric, input: Tensor, target: Tensor, preprocess: bool = T
def merge_state(self, metrics: Iterable):
    r"""
    Merging state from other metric objects is not supported.

    Raises:
        NotImplementedError: Always.
    """
    raise NotImplementedError()

# Due to an issue with PyTorch, we cannot decorate input/target with @torch.inference_mode()
# Otherwise, we will encounter the following error when using "gloo" backend:
# Inplace update to inference tensor outside InferenceMode is not allowed
def _sync(self, tensor: Tensor):
    r"""
    Gather `tensor` from every process and concatenate the pieces along dim 0.

    Processes may hold a different number of rows, so the row counts are
    exchanged first, each local tensor is padded to the largest count, the
    padded tensors are all-gathered, and the padding is stripped again before
    concatenation.

    Args:
        tensor: Local tensor to share. Only the size of dim 0 may differ
            between processes.
            NOTE(review): trailing dimensions and dtype are assumed to be
            identical on every rank - confirm against callers.

    Returns:
        The rows of all processes, in rank order, concatenated along dim 0.
        If no process holds any row, an empty tensor with the same trailing
        shape, dtype and device as `tensor` is returned.
    """
    # Step 1: exchange row counts so that every rank knows how much to pad and trim.
    local_size = torch.tensor([tensor.shape[0]], dtype=torch.int64, device=tensor.device)
    size_list = [torch.zeros_like(local_size) for _ in range(self.world_size)]
    dist.all_gather(size_list, local_size)
    # Convert once to Python ints instead of indexing a tensor for every rank below.
    sizes = torch.cat(size_list).tolist()
    max_size = max(sizes)

    # Every rank sees the same `sizes`, so all ranks take this branch together and
    # no collective call is left unmatched. Without it `torch.cat` would receive an
    # empty list and raise.
    if max_size == 0:
        return torch.empty((0, *tensor.shape[1:]), dtype=tensor.dtype, device=tensor.device)

    # Step 2: pad to a common shape, because all_gather requires equal-sized buffers.
    padded_tensor = torch.empty((max_size, *tensor.shape[1:]), dtype=tensor.dtype, device=tensor.device)
    padded_tensor[: tensor.shape[0]] = tensor
    gathered_tensors = [torch.empty_like(padded_tensor) for _ in range(self.world_size)]
    dist.all_gather(gathered_tensors, padded_tensor)

    # Step 3: drop the padding (and ranks that contributed nothing) before joining.
    slices = [gathered[:size] for gathered, size in zip(gathered_tensors, sizes) if size > 0]
    return torch.cat(slices, dim=0)

@property
def input(self):
world_size = get_world_size()
if world_size == 1:
if isinstance(self._input, NestedTensor) and not self.return_nested:
return self._input.concat
return self._input
if isinstance(self._input, Tensor):
synced_tensor = [torch.zeros_like(self._input) for _ in range(world_size)]
dist.all_gather(synced_tensor, self._input)
return torch.cat(synced_tensor, 0)
if isinstance(self._input, NestedTensor):
synced_tensors = [None for _ in range(world_size)]
dist.all_gather_object(synced_tensors, self._input.storage())
synced_tensors = flist(i.to(self.device) for j in synced_tensors for i in j)
try:
return torch.cat(synced_tensors, 0)
except RuntimeError:
input = NestedTensor(synced_tensors)
if self.return_nested:
return input
return input.concat
raise ValueError(f"Expected _input to be a Tensor or a NestedTensor, but got {type(self._input)}")
return self._input

@property
def target(self):
world_size = get_world_size()
if world_size == 1:
if isinstance(self._target, NestedTensor) and not self.return_nested:
return self._target.concat
return self._target
if isinstance(self._target, Tensor):
synced_tensor = [torch.zeros_like(self._target) for _ in range(world_size)]
dist.all_gather(synced_tensor, self._target)
return torch.cat(synced_tensor, 0)
if isinstance(self._target, NestedTensor):
synced_tensors = [None for _ in range(world_size)]
dist.all_gather_object(synced_tensors, self._target.storage())
synced_tensors = flist(i.to(self.device) for j in synced_tensors for i in j)
try:
return torch.cat(synced_tensors, 0)
except RuntimeError:
target = NestedTensor(synced_tensors)
if self.return_nested:
return target
return target.concat
raise ValueError(f"Expected _target to be a Tensor or a NestedTensor, but got {type(self._target)}")
return self._target

@property
def inputs(self):
if not self._inputs and not self._input_buffer:
return torch.empty(0)
if self._input_buffer:
world_size = get_world_size()
if world_size > 1:
synced_tensors = [None for _ in range(world_size)]
dist.all_gather_object(synced_tensors, self._input_buffer)
self._inputs.extend([i for j in synced_tensors for i in j])
else:
self._inputs.extend(self._input_buffer)
self._input_buffer = flist()
try:
return torch.cat(self._inputs, 0)
except RuntimeError:
inputs = NestedTensor(self._inputs)
if self.return_nested:
return inputs
return inputs.concat
return self._inputs

@property
def targets(self):
if not self._targets and not self._target_buffer:
return torch.empty(0)
if self._target_buffer:
world_size = get_world_size()
if world_size > 1:
synced_tensors = [None for _ in range(world_size)]
dist.all_gather_object(synced_tensors, self._target_buffer)
self._targets.extend([i for j in synced_tensors for i in j])
else:
self._targets.extend(self._target_buffer)
self._target_buffer = flist()
try:
return torch.cat(self._targets, 0)
except RuntimeError:
targets = NestedTensor(self._targets)
if self.return_nested:
return targets
return targets.concat
return self._targets

def __repr__(self):
keys = tuple(i for i in self.metrics.keys())
Expand Down Expand Up @@ -504,7 +424,10 @@ class MultiTaskMetrics(MultiTaskDict):
def __init__(self, *args, **kwargs):
super().__init__(*args, default_factory=MultiTaskMetrics, **kwargs)

def update(self, values: Mapping[str, Mapping[str, Tensor | NestedTensor | Sequence]]) -> None:
def update(
self,
values: Mapping[str, Mapping[str, Tensor | NestedTensor | Sequence]],
) -> None:
r"""
Updates the average and current value in all metrics.
Expand Down
Loading

0 comments on commit 18a401a

Please sign in to comment.