Skip to content

Commit

Permalink
fix tests
Browse files Browse the repository at this point in the history
Signed-off-by: Zhiyuan Chen <[email protected]>
  • Loading branch information
ZhiyuanChen committed Aug 5, 2024
1 parent 83f8397 commit e87ee00
Show file tree
Hide file tree
Showing 8 changed files with 44 additions and 36 deletions.
Binary file removed .coverage.haras-MacBook-Pro.local.65844.XelrAxox
Binary file not shown.
1 change: 1 addition & 0 deletions .gitignore
Original file line number Diff line number Diff line change
Expand Up @@ -416,6 +416,7 @@ danling/_version.py
# pytest
data
examples
.coverage*

# experiments
experiments
37 changes: 11 additions & 26 deletions danling/metrics/average_meter.py
Original file line number Diff line number Diff line change
Expand Up @@ -219,41 +219,26 @@ class MultiTaskAverageMeters(MultiTaskDict):
Examples:
>>> meters = MultiTaskAverageMeters()
>>> meters.update({"loss": 0.6, "dataset1.cls.auroc": 0.7, "dataset1.reg.r2": 0.8, "dataset2.r2": 0.9})
>>> print(f"{meters:.4f}")
loss: 0.6000 (0.6000)
dataset1.cls.auroc: 0.7000 (0.7000)
dataset1.reg.r2: 0.8000 (0.8000)
dataset2.r2: 0.9000 (0.9000)
>>> f"{meters:.4f}"
'loss: 0.6000 (0.6000)\ndataset1.cls.auroc: 0.7000 (0.7000)\ndataset1.reg.r2: 0.8000 (0.8000)\ndataset2.r2: 0.9000 (0.9000)'
>>> meters['loss'].update(0.9, n=1)
>>> print(f"{meters:.4f}")
loss: 0.9000 (0.7500)
dataset1.cls.auroc: 0.7000 (0.7000)
dataset1.reg.r2: 0.8000 (0.8000)
dataset2.r2: 0.9000 (0.9000)
>>> f"{meters:.4f}"
'loss: 0.9000 (0.7500)\ndataset1.cls.auroc: 0.7000 (0.7000)\ndataset1.reg.r2: 0.8000 (0.8000)\ndataset2.r2: 0.9000 (0.9000)'
>>> meters.sum.dict()
{'loss': 1.5, 'dataset1': {'cls': {'auroc': 0.7}, 'reg': {'r2': 0.8}}, 'dataset2': {'r2': 0.9}}
>>> meters.count.dict()
{'loss': 2, 'dataset1': {'cls': {'auroc': 1}, 'reg': {'r2': 1}}, 'dataset2': {'r2': 1}}
>>> meters.reset()
>>> print(f"{meters:.4f}")
loss: 0.0000 (nan)
dataset1.cls.auroc: 0.0000 (nan)
dataset1.reg.r2: 0.0000 (nan)
dataset2.r2: 0.0000 (nan)
>>> f"{meters:.4f}"
'loss: 0.0000 (nan)\ndataset1.cls.auroc: 0.0000 (nan)\ndataset1.reg.r2: 0.0000 (nan)\ndataset2.r2: 0.0000 (nan)'
>>> meters = MultiTaskAverageMeters(return_average=True)
>>> meters.update({"loss": 0.6, "dataset1.a.auroc": 0.7, "dataset1.b.auroc": 0.8, "dataset2.auroc": 0.9})
>>> print(f"{meters:.4f}")
loss: 0.6000 (0.6000)
dataset1.a.auroc: 0.7000 (0.7000)
dataset1.b.auroc: 0.8000 (0.8000)
dataset2.auroc: 0.9000 (0.9000)
>>> f"{meters:.4f}"
'loss: 0.6000 (0.6000)\ndataset1.a.auroc: 0.7000 (0.7000)\ndataset1.b.auroc: 0.8000 (0.8000)\ndataset2.auroc: 0.9000 (0.9000)'
>>> meters.update({"loss": 0.9, "dataset1.a.auroc": 0.8, "dataset1.b.auroc": 0.9, "dataset2.auroc": 1.0})
>>> print(f"{meters:.4f}")
loss: 0.9000 (0.7500)
dataset1.a.auroc: 0.8000 (0.7500)
dataset1.b.auroc: 0.9000 (0.8500)
dataset2.auroc: 1.0000 (0.9500)
"""
>>> f"{meters:.4f}"
'loss: 0.9000 (0.7500)\ndataset1.a.auroc: 0.8000 (0.7500)\ndataset1.b.auroc: 0.9000 (0.8500)\ndataset2.auroc: 1.0000 (0.9500)'
""" # noqa: E501

@property
def sum(self) -> NestedDict[str, float]:
Expand Down
26 changes: 18 additions & 8 deletions danling/metrics/metric_meter.py
Original file line number Diff line number Diff line change
Expand Up @@ -136,9 +136,9 @@ class MetricMeters(AverageMeters):
>>> meters['auroc'].update([0.4, 0.8, 0.6, 0.2], [0, 1, 1, 0])
>>> meters.avg.dict()
{'acc': 0.6, 'auroc': 0.775, 'auprc': 0.55}
>>> meters.update(dict(loss=""))
>>> meters.update(dict(loss="")) # doctest: +ELLIPSIS
Traceback (most recent call last):
TypeError: MetricMeters.update() missing 1 required positional argument: 'target'
TypeError: ...update() missing 1 required positional argument: 'target'
"""

ignored_index: Optional[int] = None
Expand Down Expand Up @@ -200,11 +200,11 @@ class MultiTaskMetricMeters(MultiTaskAverageMeters):
)
('dataset2'): MetricMeters('acc',)
)
>>> metrics.update({"dataset1.cls": {"input": [0.2, 0.4, 0.5, 0.7], "target": [0, 1, 0, 1]}, "dataset2": {"input": [0.1, 0.4, 0.6, 0.8], "target": [1, 0, 0, 0]}})
>>> metrics.update({"dataset1.cls": {"input": [0.2, 0.4, 0.5, 0.7], "target": [0, 1, 0, 1]}, "dataset2": ([0.1, 0.4, 0.6, 0.8], [1, 0, 0, 0])})
>>> f"{metrics:.4f}"
'dataset1.cls: acc: 0.5000 (0.5000)\ndataset2: acc: 0.2500 (0.2500)'
>>> metrics.setattr("return_average", True)
>>> metrics.update({"dataset1.cls": {"input": [0.1, 0.4, 0.6, 0.8], "target": [0, 0, 1, 0]}, "dataset2": {"input": [0.2, 0.3, 0.5, 0.7], "target": [0, 0, 0, 1]}})
>>> metrics.update({"dataset1.cls": [[0.1, 0.4, 0.6, 0.8], [0, 0, 1, 0]], "dataset2": {"input": [0.2, 0.3, 0.5, 0.7], "target": [0, 0, 0, 1]}})
>>> f"{metrics:.4f}"
'dataset1.cls: acc: 0.7500 (0.6250)\ndataset2: acc: 0.7500 (0.5000)'
""" # noqa: E501
Expand All @@ -225,21 +225,31 @@ def update( # type: ignore[override] # pylint: disable=W0221
"""

for metric, value in values.items():
if isinstance(value, Mapping):
if isinstance(value, (Mapping, Sequence)):
if metric not in self:
raise ValueError(f"Metric {metric} not found in {self}")
if isinstance(self[metric], MultiTaskMetricMeters):
for met in self[metric].all_values():
met.update(*value)
if isinstance(value, Mapping):
met.update(**value)
elif isinstance(value, Sequence):
met.update(*value)
else:
raise ValueError(f"Expected value to be a Mapping or Sequence, but got {type(value)}")
elif isinstance(self[metric], (MetricMeters, MetricMeter)):
self[metric].update(*value)
if isinstance(value, Mapping):
self[metric].update(**value)
elif isinstance(value, Sequence):
self[metric].update(*value)
else:
raise ValueError(f"Expected value to be a Mapping or Sequence, but got {type(value)}")
else:
raise ValueError(
f"Expected {metric} to be an instance of MultiTaskMetricMeters, MetricMeters, "
"or MetricMeter, but got {type(self[metric])}"
)
else:
raise ValueError(f"Expected values to be a flat dictionary, but got {type(value)}")
raise ValueError(f"Expected values to be a Mapping or Sequence, but got {type(value)}")

# MultiTaskAverageMeters.get is hacked
def get(self, name: Any, default=None) -> Any:
Expand Down
9 changes: 8 additions & 1 deletion danling/optim/lr_scheduler/lr_scheduler.py
Original file line number Diff line number Diff line change
Expand Up @@ -73,7 +73,14 @@ class LRScheduler(lr_scheduler._LRScheduler): # pylint: disable=protected-acces
... scheduler.step()
>>> [round(lr, 10) for lr in lrs]
[0.3330753446, 0.0187302031, 0.000533897, 3.00232e-05, 1e-09]
"""
>>> scheduler = LRScheduler(optimizer, total_steps=5, final_lr_ratio=1e-5, strategy='linear', method='numerical')
>>> lrs = []
>>> for epoch in range(5):
... lrs.append(scheduler.get_lr()[0])
... scheduler.step()
>>> [round(lr, 2) for lr in lrs]
[0.8, 0.6, 0.4, 0.2, 0.0]
""" # noqa: E501

def __init__(
self,
Expand Down
2 changes: 2 additions & 0 deletions danling/runner/accelerate_runner.py
Original file line number Diff line number Diff line change
Expand Up @@ -202,6 +202,8 @@ def train_epoch(self, split: str = "train") -> NestedDict:
interval = iteration - last_print_iteration
if self.device == torch.device("cuda"):
torch.cuda.synchronize()
if self.scheduler is not None:
self.meters.lr.update(self.scheduler.get_last_lr()[0])
self.meters.time.update((time() - batch_time) / interval)
batch_time = time()
reduced_loss = self.reduce(loss).item()
Expand Down
3 changes: 3 additions & 0 deletions demo/vision/torch_mnist.py
Original file line number Diff line number Diff line change
Expand Up @@ -46,6 +46,7 @@ def __init__(self):
self.optim.weight_decay = 1e-4

def post(self):
self.copy_class_attributes()
self.experiment_name = f"{self.network.name}_{self.optim.name}@{self.optim.lr}"


Expand All @@ -65,6 +66,8 @@ def __init__(self, config: Config):
self.model = getattr(torchvision.models, self.network.name)(pretrained=False, num_classes=10)
self.model.conv1 = nn.Conv2d(1, 64, 1, bias=False)
self.optimizer = OPTIMIZERS.build(params=self.model.parameters(), **self.optim)
train_steps = len(self.datasets.train) // self.dataloader.batch_size * (self.epoch_end - self.epoch_begin)
self.scheduler = dl.optim.lr_scheduler.LRScheduler(self.optimizer, total_steps=train_steps)
self.criterion = nn.CrossEntropyLoss()

self.metrics = dl.metrics.multiclass_metrics(num_classes=10)
Expand Down
2 changes: 1 addition & 1 deletion pyproject.toml
Original file line number Diff line number Diff line change
Expand Up @@ -42,7 +42,7 @@ dynamic = [
]
dependencies = [
"cached-property; python_version<'3.8'",
"chanfig>=0.0.96",
"chanfig>=0.0.96,!=0.0.101,!=0.0.102",
"gitpython",
"lazy-imports",
"strenum; python_version<'3.11'",
Expand Down

0 comments on commit e87ee00

Please sign in to comment.