Commit 8af934a: Improve classwise metric logging

PiperOrigin-RevId: 519994859
vdumoulin authored and copybara-github committed Mar 28, 2023
1 parent a91c68f

Showing 5 changed files with 74 additions and 81 deletions.
chirp/models/cmap.py (12 changes: 4 additions & 8 deletions)

@@ -44,16 +44,12 @@ def compute(self, sample_threshold: int = 0):
     # Same as sklearn's average_precision_score(label, logits, average=None)
     # but that implementation doesn't scale to 10k+ classes
     class_aps = metrics.average_precision(
-        values["label_logits"][:, mask].T, values["label"][:, mask].T
+        values["label_logits"].T, values["label"].T
     )
+    class_aps = jnp.where(mask, class_aps, jnp.nan)
     return {
-        "macro": jnp.mean(class_aps),
-        **{
-            str(i): ap
-            for i, ap in zip(
-                jnp.arange(values["label"].shape[1])[mask], class_aps
-            )
-        },
+        "macro": jnp.mean(class_aps, where=mask),
+        "individual": class_aps,
     }
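
This cmap.py hunk is the core of the change: average precision is now
computed over all classes, classes outside the mask are set to NaN in the
per-class vector, and the macro average is taken only where the mask holds.
A minimal sketch of that masking pattern (toy values, not from the commit):

    import jax.numpy as jnp

    class_aps = jnp.array([0.9, 0.4, 0.7, 0.2])
    mask = jnp.array([True, False, True, True])  # class 1 absent at eval time

    individual = jnp.where(mask, class_aps, jnp.nan)  # [0.9, nan, 0.7, 0.2]
    macro = jnp.mean(class_aps, where=mask)           # (0.9+0.7+0.2)/3 = 0.6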


chirp/train/classifier.py (51 changes: 23 additions & 28 deletions)

@@ -239,16 +239,20 @@ def update_step(key, batch, train_state):
     train_metrics, train_state = update_step(step_key, batch, train_state)

     if step % log_every_steps == 0:
-      train_metrics = flax_utils.unreplicate(train_metrics).compute()
+      train_metrics = utils.flatten_dict(
+          flax_utils.unreplicate(train_metrics).compute()
+      )

-      metrics_kept = {}
-      for k, v in train_metrics.items():
-        if "xentropy" in k and not add_class_wise_metrics:
-          continue
-        metrics_kept[k] = v
-      train_metrics = metrics_kept
+      classwise_metrics = {
+          k: v for k, v in train_metrics.items() if "individual" in k
+      }
+      train_metrics = {
+          k: v for k, v in train_metrics.items() if k not in classwise_metrics
+      }

-      writer.write_scalars(step, utils.flatten_dict(train_metrics))
+      writer.write_scalars(step, train_metrics)
+      if add_class_wise_metrics:
+        writer.write_summaries(step, classwise_metrics)
       reporter(step)

     if (step + 1) % checkpoint_every_steps == 0 or step == num_train_steps:
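
Both training loops (this one and the matching hunk in hubert.py below) now
partition the flattened metrics dict by key instead of filtering on
"xentropy": array-valued per-class entries, whose keys contain "individual",
are routed to write_summaries, while everything else stays in write_scalars.
A hedged sketch of that routing, with illustrative key names and assuming
any clu.metric_writers-style writer:

    import numpy as np
    from clu import metric_writers

    writer = metric_writers.create_default_writer("/tmp/logs")
    step = 0
    flat = {"loss": 0.31, "xentropy_individual": np.array([0.2, 0.5, 0.4])}

    classwise = {k: v for k, v in flat.items() if "individual" in k}
    scalars = {k: v for k, v in flat.items() if k not in classwise}

    writer.write_scalars(step, scalars)      # scalar-valued entries only
    writer.write_summaries(step, classwise)  # array-valued, one per class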

@@ -361,26 +365,17 @@ def remainder_batch_fn(x):
         break

     # Log validation loss
-    valid_metrics = valid_metrics.compute()
-
-    if not add_class_wise_metrics:
-      metrics_kept = {}
-      for k, v in valid_metrics.items():
-        if "xentropy" in k:
-          # Only the class-wise xentropy metrics contain the string 'xentropy';
-          # the key corresponding to overall xentropy is called 'loss'.
-          continue
-        metrics_kept[k] = v
-      valid_metrics = metrics_kept
-
-    for k, v in valid_metrics.items():
-      # Only one of the keys of valid_metrics will contain the string 'cmap',
-      # and the associated value is a dict that has a 'macro' key as well as
-      # a key per class. To disable class-wise metrics, we keep only 'macro'.
-      if "_cmap" in k:
-        valid_metrics[k] = v["macro"]
-
-    writer.write_scalars(step, utils.flatten_dict(valid_metrics))
+    valid_metrics = utils.flatten_dict(valid_metrics.compute())
+    classwise_metrics = {
+        k: v for k, v in valid_metrics.items() if "individual" in k
+    }
+    valid_metrics = {
+        k: v for k, v in valid_metrics.items() if k not in classwise_metrics
+    }
+
+    writer.write_scalars(step, valid_metrics)
+    if add_class_wise_metrics:
+      writer.write_summaries(step, classwise_metrics)
     writer.flush()
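
The "individual" in k test works because utils.flatten_dict surfaces each
metric's nested "individual" key (see the cmap.py and utils.py hunks) in the
flattened key names. A sketch of the assumed behavior; the exact key
separator is a guess, not taken from the source:

    import jax.numpy as jnp

    class_aps = jnp.array([0.9, jnp.nan, 0.7])
    nested = {"valid_cmap": {"macro": 0.8, "individual": class_aps}}
    # Assumed flattening:
    # {"valid_cmap_macro": 0.8, "valid_cmap_individual": class_aps}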


chirp/train/hubert.py (57 changes: 27 additions & 30 deletions)

@@ -855,16 +855,20 @@ def step(params, model_state):
     )

     if step % log_every_steps == 0:
-      train_metrics = flax_utils.unreplicate(train_metrics).compute()
+      train_metrics = utils.flatten_dict(
+          flax_utils.unreplicate(train_metrics).compute()
+      )

-      metrics_kept = {}
-      for k, v in train_metrics.items():
-        if "xentropy" in k and not add_class_wise_metrics:
-          continue
-        metrics_kept[k] = v
-      train_metrics = metrics_kept
+      classwise_metrics = {
+          k: v for k, v in train_metrics.items() if "individual" in k
+      }
+      train_metrics = {
+          k: v for k, v in train_metrics.items() if k not in classwise_metrics
+      }

-      writer.write_scalars(step, utils.flatten_dict(train_metrics))
+      writer.write_scalars(step, train_metrics)
+      if add_class_wise_metrics:
+        writer.write_summaries(step, classwise_metrics)
       reporter(step)

     if (step + 1) % checkpoint_every_steps == 0 or step == num_train_steps:

@@ -955,42 +959,35 @@ def get_metrics(batch, train_state, mask_key):
   step = int(flax_utils.unreplicate(train_state.step))
   key = model_bundle.key
   with reporter.timed("eval"):
-    valid_metrics = flax_utils.replicate(valid_metrics_collection.empty())
+    valid_metrics = valid_metrics_collection.empty()
     for s, batch in enumerate(valid_dataset.as_numpy_iterator()):
       batch = jax.tree_map(np.asarray, batch)
       mask_key = None
       if mask_at_eval:
         mask_key, key = random.split(key)
         mask_key = random.split(mask_key, num=jax.local_device_count())
       new_valid_metrics = get_metrics(batch, train_state, mask_key)
-      valid_metrics = valid_metrics.merge(new_valid_metrics)
+      valid_metrics = valid_metrics.merge(
+          flax_utils.unreplicate(new_valid_metrics)
+      )
       if (
           eval_steps_per_checkpoint is not None
           and s >= eval_steps_per_checkpoint
       ):
         break

     # Log validation loss
-    valid_metrics = flax_utils.unreplicate(valid_metrics).compute()
-
-    if not add_class_wise_metrics:
-      metrics_kept = {}
-      for k, v in valid_metrics.items():
-        if "xentropy" in k:
-          # Only the class-wise xentropy metrics contain the string 'xentropy';
-          # the key corresponding to overall xentropy is called 'loss'.
-          continue
-        metrics_kept[k] = v
-      valid_metrics = metrics_kept
-
-    for k, v in valid_metrics.items():
-      # Only one of the keys of valid_metrics will contain the string 'cmap',
-      # and the associated value is a dict that has a 'macro' key as well as
-      # a key per class. To disable class-wise metrics, we keep only 'macro'.
-      if "_cmap" in k:
-        valid_metrics[k] = v["macro"]
-
-    writer.write_scalars(step, utils.flatten_dict(valid_metrics))
+    valid_metrics = utils.flatten_dict(valid_metrics.compute())
+    classwise_metrics = {
+        k: v for k, v in valid_metrics.items() if "individual" in k
+    }
+    valid_metrics = {
+        k: v for k, v in valid_metrics.items() if k not in classwise_metrics
+    }
+
+    writer.write_scalars(step, valid_metrics)
+    if add_class_wise_metrics:
+      writer.write_summaries(step, classwise_metrics)
     writer.flush()
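
In hubert.py and separator.py the eval-metrics accumulator also moves to the
host: the collection starts out unreplicated and each pmapped result is
unreplicated before merging, so only one copy of the accumulator is kept
between eval batches. As a reminder, flax_utils.unreplicate roughly takes
the first copy along the leading device axis (a sketch, not the real
implementation):

    import jax

    def unreplicate(tree):
      # Drop the device axis by selecting each leaf's first replica.
      return jax.tree_map(lambda x: x[0], tree)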


chirp/train/separator.py (33 changes: 19 additions & 14 deletions)

@@ -322,30 +322,35 @@ def get_metrics(batch, train_state):
   )

   with reporter.timed('eval'):
-    valid_metrics = flax.jax_utils.replicate(valid_metrics_collection.empty())
+    valid_metrics = valid_metrics_collection.empty()
     for valid_step, batch in enumerate(valid_dataset.as_numpy_iterator()):
       batch = jax.tree_map(np.asarray, batch)
       new_valid_metrics = get_metrics(batch, flax_utils.replicate(train_state))
-      valid_metrics = valid_metrics.merge(new_valid_metrics)
+      valid_metrics = valid_metrics.merge(
+          flax_utils.unreplicate(new_valid_metrics)
+      )
       if (
           eval_steps_per_checkpoint > 0
           and valid_step >= eval_steps_per_checkpoint
       ):
         break

     # Log validation loss
-    valid_metrics = flax_utils.unreplicate(valid_metrics).compute()
-
-    if not add_class_wise_metrics:
-      metrics_kept = {}
-      for k, v in valid_metrics.items():
-        if '_cmap_' in k and not v.endswith('_cmap_macro'):
-          # Discard metrics like 'valid_cmap_442' keeping only 'valid_cmap_macro'.
-          continue
-        metrics_kept[k] = v
-      valid_metrics = metrics_kept
-    valid_metrics = {k.replace('___', '/'): v for k, v in valid_metrics.items()}
-    writer.write_scalars(int(train_state.step), utils.flatten_dict(valid_metrics))
+    valid_metrics = valid_metrics.compute()
+
+    valid_metrics = utils.flatten_dict(
+        {k.replace('___', '/'): v for k, v in valid_metrics.items()}
+    )
+    classwise_metrics = {
+        k: v for k, v in valid_metrics.items() if 'individual' in k
+    }
+    valid_metrics = {
+        k: v for k, v in valid_metrics.items() if k not in classwise_metrics
+    }
+
+    writer.write_scalars(int(train_state.step), valid_metrics)
+    if add_class_wise_metrics:
+      writer.write_summaries(int(train_state.step), classwise_metrics)
     writer.flush()
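
separator.py additionally rewrites '___' in the computed metric keys to '/'
before flattening, which produces namespaced tags in the metric writer. A
toy illustration (the key shown is hypothetical):

    metrics = {'valid___mixit_neg_snr': 0.0}
    renamed = {k.replace('___', '/'): v for k, v in metrics.items()}
    # {'valid/mixit_neg_snr': 0.0}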


chirp/train/utils.py (2 changes: 1 addition & 1 deletion)

@@ -102,7 +102,7 @@ def compute(self):
     averages = self.total / self.count
     return {
         "mean": jnp.sum(self.total) / jnp.sum(self.count),
-        **{str(i): averages[i] for i in range(jnp.size(averages))},
+        "individual": averages,
     }
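
This utils.py metric now follows the same scheme as cmap.py: one "mean"
scalar plus one "individual" array, rather than a separate dict key per
class. A toy evaluation of the new return value, assuming self.total and
self.count hold per-class sums and counts as the surrounding code suggests:

    import jax.numpy as jnp

    total = jnp.array([3.0, 1.0, 0.5])   # per-class summed values
    count = jnp.array([10.0, 4.0, 1.0])  # per-class observation counts

    out = {
        "mean": jnp.sum(total) / jnp.sum(count),  # 4.5 / 15 = 0.3
        "individual": total / count,              # [0.3, 0.25, 0.5]
    }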

