diff --git a/hoplite/agile/classifier.py b/hoplite/agile/classifier.py
index 71dc763..1db9302 100644
--- a/hoplite/agile/classifier.py
+++ b/hoplite/agile/classifier.py
@@ -16,14 +16,15 @@
 """Functions for training and applying a linear classifier."""
 
 import base64
+from concurrent import futures
 import dataclasses
 import json
-from typing import Any, Sequence
+from typing import Any, Iterator, Sequence
 
+from etils import epath
 from hoplite.agile import classifier_data
 from hoplite.agile import metrics
 from hoplite.db import interface as db_interface
-from hoplite.taxonomy import namespace
 from ml_collections import config_dict
 import numpy as np
 import tensorflow as tf
@@ -230,50 +231,129 @@ def train_step(y_true, embeddings, is_labeled_mask):
   return linear_classifier, eval_scores
 
 
-def write_inference_csv(
-    linear_classifier: LinearClassifier,
-    db: db_interface.HopliteDBInterface,
-    output_filepath: str,
-    threshold: float,
-    labels: Sequence[str] | None = None,
-):
-  """Write a CSV for all audio windows with logits above a threshold.
+@dataclasses.dataclass
+class CsvWorkerState:
+  """State for the CSV worker.
+
+  Attributes:
+    db: The base database from the parent thread.
+    csv_filepath: The path to the CSV file to write.
+    labels: The labels to write.
+    threshold: The threshold for writing detections.
+    _thread_db: The database to use in child threads.
+  """
+
+  db: db_interface.HopliteDBInterface
+  csv_filepath: str
+  labels: tuple[str, ...]
+  threshold: float
+  _thread_db: db_interface.HopliteDBInterface | None = None
+
+  def get_thread_db(self) -> db_interface.HopliteDBInterface:
+    if self._thread_db is None:
+      self._thread_db = self.db.thread_split()
+    return self._thread_db
+
+
+def csv_worker_initializer(state: CsvWorkerState):
+  """Initialize the CSV worker and write the CSV header."""
+  state.get_thread_db()
+  with epath.Path(state.csv_filepath).open('w') as f:
+    f.write('idx,dataset_name,source_id,offset,label,logits\n')
+
+
+def csv_worker_fn(
+    embedding_ids: np.ndarray, logits: np.ndarray, state: CsvWorkerState
+) -> None:
+  """Writes a CSV row for each detection.
 
   Args:
-    params: The parameters of the linear classifier.
-    class_list: The class list of labels associated with the classifier.
-    db: HopliteDBInterface to read embeddings from.
-    output_filepath: Path to write the CSV to.
-    threshold: Logits must be above this value to be written.
-    labels: If provided, only write logits for these labels. If None, write
-      logits for all labels.
-
-  Returns:
-    None
+    embedding_ids: The embedding ids to write.
+    logits: The logits for each embedding id.
+    state: The state of the worker.
""" - idxes = db.get_embedding_ids() - if labels is None: - labels = linear_classifier.classes - label_ids = {cl: i for i, cl in enumerate(linear_classifier.classes)} - target_label_ids = np.array([label_ids[l] for l in labels]) - logits_fn = lambda emb: linear_classifier(emb)[target_label_ids] - detection_count = 0 - with open(output_filepath, 'w') as f: - f.write('idx,dataset_name,source_id,offset,label,logits\n') - for idx in tqdm.tqdm(idxes): + db = state.get_thread_db() + with epath.Path(state.csv_filepath).open('a') as f: + for idx, logit in zip(embedding_ids, logits): source = db.get_embedding_source(idx) - emb = db.get_embedding(idx) - logits = logits_fn(emb) - for a in np.argwhere(logits > threshold): - lbl = labels[a[0]] + for a in np.argwhere(logit > state.threshold): + lbl = state.labels[a[0]] row = [ idx, source.dataset_name, source.source_id, source.offsets[0], lbl, - logits[a], + logit[a][0], ] f.write(','.join(map(str, row)) + '\n') - detection_count += 1 + + +def batched_embedding_iterator( + db: db_interface.HopliteDBInterface, + embedding_ids: np.ndarray, + batch_size: int = 1024, +) -> Iterator[tuple[np.ndarray, np.ndarray]]: + """Iterate over embeddings in batches.""" + for q in range(0, len(embedding_ids), batch_size): + batch_ids = embedding_ids[q : q + batch_size] + batch_ids, batch_embs = db.get_embeddings(batch_ids) + yield batch_ids, batch_embs + + +def write_inference_csv( + linear_classifier: LinearClassifier, + db: db_interface.HopliteDBInterface, + output_filepath: str, + threshold: float, + labels: Sequence[str] | None = None, + embedding_ids: np.ndarray | None = None, +) -> None: + """Write a CSV for all audio windows with logits above a threshold.""" + if embedding_ids is None: + embedding_ids = db.get_embedding_ids() + if labels is None: + labels = linear_classifier.classes + else: + labels = tuple(set(labels).intersection(linear_classifier.classes)) + label_ids = {cl: i for i, cl in enumerate(linear_classifier.classes)} + target_label_ids = np.array([label_ids[l] for l in labels]) + logits_fn = lambda batch_embs: linear_classifier(batch_embs)[ + :, target_label_ids + ] + detection_count = 0 + state = CsvWorkerState( + db=db, + csv_filepath=output_filepath, + labels=labels, + threshold=threshold, + ) + emb_iter = batched_embedding_iterator(db, embedding_ids, batch_size=1024) + with futures.ThreadPoolExecutor( + max_workers=1, + initializer=csv_worker_initializer, + initargs=(state,), + ) as executor: + for batch_idxes, batch_embs in tqdm.tqdm(emb_iter): + logits = logits_fn(batch_embs) + # Filter out rows with no detections, avoiding extra database retrievals. 
+      detections = logits > threshold
+      keep_rows = detections.any(axis=1)
+      logits = logits[keep_rows]
+      kept_idxes = batch_idxes[keep_rows]
+      executor.submit(csv_worker_fn, kept_idxes, logits, state)
+      detection_count += detections.sum()
   print(f'Wrote {detection_count} detections to {output_filepath}')
diff --git a/hoplite/agile/tests/classifier_test.py b/hoplite/agile/tests/classifier_test.py
index c3b4313..931f2e2 100644
--- a/hoplite/agile/tests/classifier_test.py
+++ b/hoplite/agile/tests/classifier_test.py
@@ -19,8 +19,11 @@
 import tempfile
 
 from hoplite.agile import classifier
+from hoplite.agile import classifier_data
+from hoplite.db.tests import test_utils as db_test_utils
 from ml_collections import config_dict
 import numpy as np
+import pandas as pd
 from absl.testing import absltest
 
 
@@ -69,6 +72,93 @@ def test_save_load_linear_classifier(self):
     self.assertSequenceEqual(classy_loaded.classes, classy.classes)
     self.assertEqual(classy_loaded.embedding_model_config.model_name, 'nelson')
 
+  def test_train_linear_classifier(self):
+    rng = np.random.default_rng(1234)
+    embedding_dim = 8
+    db = db_test_utils.make_db(
+        path=self.tempdir,
+        db_type='in_mem',
+        num_embeddings=1024,
+        rng=rng,
+        embedding_dim=embedding_dim,
+    )
+    db_test_utils.add_random_labels(
+        db, rng=rng, unlabeled_prob=0.5, positive_label_prob=0.1
+    )
+    data_manager = classifier_data.AgileDataManager(
+        target_labels=db_test_utils.CLASS_LABELS,
+        db=db,
+        train_ratio=0.8,
+        min_eval_examples=5,
+        batch_size=32,
+        weak_negatives_batch_size=10,
+        rng=np.random.default_rng(42),
+    )
+    lc, eval_scores = classifier.train_linear_classifier(
+        data_manager,
+        learning_rate=0.01,
+        weak_neg_weight=0.5,
+        num_train_steps=128,
+        loss='bce',
+    )
+    self.assertIsInstance(lc, classifier.LinearClassifier)
+    np.testing.assert_equal(
+        lc.beta.shape, (embedding_dim, len(db_test_utils.CLASS_LABELS))
+    )
+    np.testing.assert_equal(
+        lc.beta_bias.shape, (len(db_test_utils.CLASS_LABELS),)
+    )
+    self.assertIn('roc_auc', eval_scores)
+
+  def test_write_inference_csv(self):
+    embedding_dim = 8
+    rng = np.random.default_rng(1234)
+    db = db_test_utils.make_db(
+        path=self.tempdir,
+        db_type='in_mem',
+        num_embeddings=1024,
+        rng=rng,
+        embedding_dim=embedding_dim,
+    )
+    db_test_utils.add_random_labels(
+        db, rng=rng, unlabeled_prob=0.5, positive_label_prob=0.1
+    )
+    classes = ('alpha', 'beta', 'delta', 'epsilon')
+    classy = self._make_linear_classifier(embedding_dim, classes)
+    inference_classes = ('alpha', 'epsilon', 'gamma')
+    classy.beta_bias = 0.0
+    csv_filepath = os.path.join(self.tempdir, 'inference.csv')
+    classifier.write_inference_csv(
+        embedding_ids=db.get_embedding_ids(),
+        linear_classifier=classy,
+        db=db,
+        output_filepath=csv_filepath,
+        threshold=0.0,
+        labels=inference_classes,
+    )
+    inference_csv = pd.read_csv(csv_filepath)
+    got_labels = np.unique(inference_csv['label'].values)
+    # `gamma` is not in the classifier's classes, so should not be in the output.
+    expected_labels = ('alpha', 'epsilon')
+    np.testing.assert_array_equal(got_labels, expected_labels)
+
+    # We can estimate the total number of detections. There are 1024 embeddings,
+    # and we will only have outputs for two classes. Each logit is > 0 with
+    # probability 0.5, because we are using an unbiased random classifier
+    # with random embeddings. So, we expect 1024 * 0.5 * 2 = 1024 detections.
+    self.assertGreater(len(inference_csv), 1000)
+    self.assertLess(len(inference_csv), 1050)
+
+    # Spot-check some of the inference scores.
+    for i in range(16):
+      emb_id = inference_csv['idx'][i]
+      lbl = inference_csv['label'][i]
+      got_logit = inference_csv['logits'][i]
+      class_idx = classy.classes.index(lbl)
+      embedding = db.get_embedding(emb_id)
+      expect_logit = classy(embedding)[class_idx]
+      self.assertEqual(np.float16(got_logit), np.float16(expect_logit))
+
 
 if __name__ == '__main__':
   absltest.main()
diff --git a/hoplite/db/tests/test_utils.py b/hoplite/db/tests/test_utils.py
index 42cd0c7..307890f 100644
--- a/hoplite/db/tests/test_utils.py
+++ b/hoplite/db/tests/test_utils.py
@@ -52,6 +52,10 @@ def make_db(
   config = config_dict.ConfigDict()
   config.embedding_dim = embedding_dim
   db.insert_metadata('db_config', config)
+  model_config = config_dict.ConfigDict()
+  model_config.embedding_dim = embedding_dim
+  model_config.model_name = 'fake_model'
+  db.insert_metadata('model_config', model_config)
   db.commit()
   return db
 
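
End-to-end usage sketch for the new batched inference path. This mirrors the tests above; the temp directory, the threshold value, and the use of the test-only helpers `make_db`/`add_random_labels` from `hoplite.db.tests.test_utils` are illustrative, not part of the change:

import os
import tempfile

import numpy as np
from hoplite.agile import classifier, classifier_data
from hoplite.db.tests import test_utils as db_test_utils

workdir = tempfile.mkdtemp()
rng = np.random.default_rng(1234)
# Build a small in-memory database with random embeddings and labels.
db = db_test_utils.make_db(
    path=workdir,
    db_type='in_mem',
    num_embeddings=1024,
    rng=rng,
    embedding_dim=8,
)
db_test_utils.add_random_labels(
    db, rng=rng, unlabeled_prob=0.5, positive_label_prob=0.1
)
data_manager = classifier_data.AgileDataManager(
    target_labels=db_test_utils.CLASS_LABELS,
    db=db,
    train_ratio=0.8,
    min_eval_examples=5,
    batch_size=32,
    weak_negatives_batch_size=10,
    rng=np.random.default_rng(42),
)
lc, eval_scores = classifier.train_linear_classifier(
    data_manager,
    learning_rate=0.01,
    weak_neg_weight=0.5,
    num_train_steps=128,
    loss='bce',
)
# Embeddings are scored in batches of 1024; only rows with at least one
# logit above the threshold are handed to the CSV worker thread.
classifier.write_inference_csv(
    linear_classifier=lc,
    db=db,
    output_filepath=os.path.join(workdir, 'inference.csv'),
    threshold=1.0,
)

Note that max_workers=1 is deliberate: with a single worker, the initializer writes the header exactly once and batches are appended in submission order, so the output is deterministic for a fixed embedding_ids ordering.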