Speed up linear classifier inference.
PiperOrigin-RevId: 704378934
sdenton4 authored and copybara-github committed Dec 10, 2024
1 parent 025f7f4 commit 29a3ad6
Showing 3 changed files with 197 additions and 36 deletions.
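
At a high level, inference now reads embeddings from the database in batches, scores each whole batch at once, and hands only the above-threshold rows to a single background thread that appends them to the CSV. A minimal usage sketch of the updated entry point (the `db` handle and trained `linear_classifier` are assumed to exist already; they are placeholders, not part of this commit):

from hoplite.agile import classifier

# Assumed inputs: `db` is any HopliteDBInterface with embeddings inserted,
# and `linear_classifier` is a trained classifier.LinearClassifier.
classifier.write_inference_csv(
    linear_classifier=linear_classifier,
    db=db,
    output_filepath='/tmp/inference.csv',
    threshold=1.0,  # only logits above this value are written
    labels=('alpha', 'epsilon'),  # optional; defaults to all classifier classes
)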
139 changes: 103 additions & 36 deletions hoplite/agile/classifier.py
@@ -16,14 +16,15 @@
"""Functions for training and applying a linear classifier."""

import base64
from concurrent import futures
import dataclasses
import json
from typing import Any, Sequence
from typing import Any, Iterator, Sequence

from etils import epath
from hoplite.agile import classifier_data
from hoplite.agile import metrics
from hoplite.db import interface as db_interface
from hoplite.taxonomy import namespace
from ml_collections import config_dict
import numpy as np
import tensorflow as tf
@@ -230,50 +231,116 @@ def train_step(y_true, embeddings, is_labeled_mask):
   return linear_classifier, eval_scores
 
 
-def write_inference_csv(
-    linear_classifier: LinearClassifier,
-    db: db_interface.HopliteDBInterface,
-    output_filepath: str,
-    threshold: float,
-    labels: Sequence[str] | None = None,
-):
-  """Write a CSV for all audio windows with logits above a threshold.
+@dataclasses.dataclass
+class CsvWorkerState:
+  """State for the CSV worker.
+
+  Params:
+    db: The base database from the parent thread.
+    csv_filepath: The path to the CSV file to write.
+    labels: The labels to write.
+    threshold: The threshold for writing detections.
+    _thread_db: The database to use in child threads.
+  """
+
+  db: db_interface.HopliteDBInterface
+  csv_filepath: str
+  labels: tuple[str, ...]
+  threshold: float
+  _thread_db: db_interface.HopliteDBInterface | None = None
+
+  def get_thread_db(self) -> db_interface.HopliteDBInterface:
+    if self._thread_db is None:
+      self._thread_db = self.db.thread_split()
+    return self._thread_db
+
+
+def csv_worker_initializer(state: CsvWorkerState):
+  """Initialize the CSV worker."""
+  state.get_thread_db()
+  with epath.Path(state.csv_filepath).open('w') as f:
+    f.write('idx,dataset_name,source_id,offset,label,logits\n')
+
+
+def csv_worker_fn(
+    embedding_ids: np.ndarray, logits: np.ndarray, state: CsvWorkerState
+) -> None:
+  """Writes a CSV row for each detection.
 
   Args:
-    params: The parameters of the linear classifier.
-    class_list: The class list of labels associated with the classifier.
-    db: HopliteDBInterface to read embeddings from.
-    output_filepath: Path to write the CSV to.
-    threshold: Logits must be above this value to be written.
-    labels: If provided, only write logits for these labels. If None, write
-      logits for all labels.
-
-  Returns:
-    None
+    embedding_ids: The embedding ids to write.
+    logits: The logits for each embedding id.
+    state: The state of the worker.
   """
-  idxes = db.get_embedding_ids()
-  if labels is None:
-    labels = linear_classifier.classes
-  label_ids = {cl: i for i, cl in enumerate(linear_classifier.classes)}
-  target_label_ids = np.array([label_ids[l] for l in labels])
-  logits_fn = lambda emb: linear_classifier(emb)[target_label_ids]
-  detection_count = 0
-  with open(output_filepath, 'w') as f:
-    f.write('idx,dataset_name,source_id,offset,label,logits\n')
-    for idx in tqdm.tqdm(idxes):
+  db = state.get_thread_db()
+  with epath.Path(state.csv_filepath).open('a') as f:
+    for idx, logit in zip(embedding_ids, logits):
       source = db.get_embedding_source(idx)
-      emb = db.get_embedding(idx)
-      logits = logits_fn(emb)
-      for a in np.argwhere(logits > threshold):
-        lbl = labels[a[0]]
+      for a in np.argwhere(logit > state.threshold):
+        lbl = state.labels[a[0]]
         row = [
             idx,
             source.dataset_name,
            source.source_id,
            source.offsets[0],
            lbl,
-            logits[a],
+            logit[a][0],
         ]
         f.write(','.join(map(str, row)) + '\n')
-        detection_count += 1
+
+
+def batched_embedding_iterator(
+    db: db_interface.HopliteDBInterface,
+    embedding_ids: np.ndarray,
+    batch_size: int = 1024,
+) -> Iterator[tuple[np.ndarray, np.ndarray]]:
+  """Iterate over embeddings in batches."""
+  for q in range(0, len(embedding_ids), batch_size):
+    batch_ids = embedding_ids[q : q + batch_size]
+    batch_ids, batch_embs = db.get_embeddings(batch_ids)
+    yield batch_ids, batch_embs
+
+
+def write_inference_csv(
+    linear_classifier: LinearClassifier,
+    db: db_interface.HopliteDBInterface,
+    output_filepath: str,
+    threshold: float,
+    labels: Sequence[str] | None = None,
+    embedding_ids: np.ndarray | None = None,
+) -> None:
+  """Write a CSV for all audio windows with logits above a threshold."""
+  if embedding_ids is None:
+    embedding_ids = db.get_embedding_ids()
+  if labels is None:
+    labels = linear_classifier.classes
+  else:
+    labels = tuple(set(labels).intersection(linear_classifier.classes))
+  label_ids = {cl: i for i, cl in enumerate(linear_classifier.classes)}
+  target_label_ids = np.array([label_ids[l] for l in labels])
+  logits_fn = lambda batch_embs: linear_classifier(batch_embs)[
+      :, target_label_ids
+  ]
+  detection_count = 0
+  state = CsvWorkerState(
+      db=db,
+      csv_filepath=output_filepath,
+      labels=labels,
+      threshold=threshold,
+  )
+  emb_iter = batched_embedding_iterator(db, embedding_ids, batch_size=1024)
+  with futures.ThreadPoolExecutor(
+      max_workers=1,
+      initializer=csv_worker_initializer,
+      initargs=(state,),
+  ) as executor:
+    for batch_idxes, batch_embs in tqdm.tqdm(emb_iter):
+      logits = logits_fn(batch_embs)
+      # Filter out rows with no detections, avoiding extra database retrievals.
+      detections = logits > threshold
+      keep_rows = detections.max(axis=1)
+      logits = logits[keep_rows]
+      kept_idxes = batch_idxes[keep_rows]
+      executor.submit(csv_worker_fn, kept_idxes, logits, state)
+      detection_count += detections.sum()
   print(f'Wrote {detection_count} detections to {output_filepath}')
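
Two details above carry the speedup: batched_embedding_iterator replaces one-at-a-time get_embedding calls with batched get_embeddings lookups, and the ThreadPoolExecutor with max_workers=1 moves CSV writing off the scoring loop while keeping writes serialized, so no lock is needed (the worker also gets its own database handle via thread_split). A stripped-down sketch of the same single-writer pattern, independent of Hoplite (the file path and batch contents are placeholders):

from concurrent import futures

def append_rows(path: str, rows: list[str]) -> None:
  # Only one worker thread exists, so appends are serialized without locking.
  with open(path, 'a') as f:
    f.writelines(rows)

with futures.ThreadPoolExecutor(max_workers=1) as pool:
  for batch_id in range(4):  # stands in for the batched scoring loop
    rows = [f'{batch_id},{row}\n' for row in range(3)]
    pool.submit(append_rows, '/tmp/out.csv', rows)
# Leaving the `with` block waits for all queued writes to finish.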
90 changes: 90 additions & 0 deletions hoplite/agile/tests/classifier_test.py
@@ -19,8 +19,11 @@
 import tempfile
 
 from hoplite.agile import classifier
+from hoplite.agile import classifier_data
+from hoplite.db.tests import test_utils as db_test_utils
 from ml_collections import config_dict
 import numpy as np
+import pandas as pd
 
 from absl.testing import absltest
@@ -69,6 +72,93 @@ def test_save_load_linear_classifier(self):
     self.assertSequenceEqual(classy_loaded.classes, classy.classes)
     self.assertEqual(classy_loaded.embedding_model_config.model_name, 'nelson')
 
+  def test_train_linear_classifier(self):
+    rng = np.random.default_rng(1234)
+    embedding_dim = 8
+    db = db_test_utils.make_db(
+        path=self.tempdir,
+        db_type='in_mem',
+        num_embeddings=1024,
+        rng=rng,
+        embedding_dim=embedding_dim,
+    )
+    db_test_utils.add_random_labels(
+        db, rng=rng, unlabeled_prob=0.5, positive_label_prob=0.1
+    )
+    data_manager = classifier_data.AgileDataManager(
+        target_labels=db_test_utils.CLASS_LABELS,
+        db=db,
+        train_ratio=0.8,
+        min_eval_examples=5,
+        batch_size=32,
+        weak_negatives_batch_size=10,
+        rng=np.random.default_rng(42),
+    )
+    lc, eval_scores = classifier.train_linear_classifier(
+        data_manager,
+        learning_rate=0.01,
+        weak_neg_weight=0.5,
+        num_train_steps=128,
+        loss='bce',
+    )
+    self.assertIsInstance(lc, classifier.LinearClassifier)
+    np.testing.assert_equal(
+        lc.beta.shape, (embedding_dim, len(db_test_utils.CLASS_LABELS))
+    )
+    np.testing.assert_equal(
+        lc.beta_bias.shape, (len(db_test_utils.CLASS_LABELS),)
+    )
+    self.assertIn('roc_auc', eval_scores)
+
+  def test_write_inference_csv(self):
+    embedding_dim = 8
+    rng = np.random.default_rng(1234)
+    db = db_test_utils.make_db(
+        path=self.tempdir,
+        db_type='in_mem',
+        num_embeddings=1024,
+        rng=rng,
+        embedding_dim=embedding_dim,
+    )
+    db_test_utils.add_random_labels(
+        db, rng=rng, unlabeled_prob=0.5, positive_label_prob=0.1
+    )
+    classes = ('alpha', 'beta', 'delta', 'epsilon')
+    classy = self._make_linear_classifier(embedding_dim, classes)
+    inference_classes = ('alpha', 'epsilon', 'gamma')
+    classy.beta_bias = 0.0
+    csv_filepath = os.path.join(self.tempdir, 'inference.csv')
+    classifier.write_inference_csv(
+        embedding_ids=db.get_embedding_ids(),
+        linear_classifier=classy,
+        db=db,
+        output_filepath=csv_filepath,
+        threshold=0.0,
+        labels=inference_classes,
+    )
+    inference_csv = pd.read_csv(csv_filepath)
+    got_labels = np.unique(inference_csv['label'].values)
+    # `gamma` is not among the classifier's classes, so it should not appear
+    # in the output.
+    expected_labels = ('alpha', 'epsilon')
+    np.testing.assert_array_equal(got_labels, expected_labels)
+
+    # We can estimate the total number of detections. There are 1024 embeddings,
+    # and we will only have outputs for two classes. Each logit is > 0 with
+    # probability 0.5, because we are using an unbiased random classifier
+    # with random embeddings. So, we expect 1024 * 0.5 * 2 = 1024 detections.
+    self.assertGreater(len(inference_csv), 1000)
+    self.assertLess(len(inference_csv), 1050)
+
+    # Spot check some of the inference scores.
+    for i in range(16):
+      emb_id = inference_csv['idx'][i]
+      lbl = inference_csv['label'][i]
+      got_logit = inference_csv['logits'][i]
+      class_idx = classy.classes.index(lbl)
+      embedding = db.get_embedding(emb_id)
+      expect_logit = classy(embedding)[class_idx]
+      self.assertEqual(np.float16(got_logit), np.float16(expect_logit))
 
 
 if __name__ == '__main__':
   absltest.main()
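
The detection-count bounds in test_write_inference_csv follow from symmetry: with zero bias, random weights, and random embeddings, each logit is positive with probability 0.5, so roughly 1024 embeddings * 2 emitted classes * 0.5 = 1024 rows are expected. A standalone back-of-the-envelope check in numpy (this does not use the test's database or classifier):

import numpy as np

rng = np.random.default_rng(0)
embeddings = rng.normal(size=(1024, 8))
beta = rng.normal(size=(8, 2))  # unbiased random two-class classifier
detections = int((embeddings @ beta > 0.0).sum())
print(detections)  # typically close to 1024 * 2 * 0.5 = 1024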
4 changes: 4 additions & 0 deletions hoplite/db/tests/test_utils.py
@@ -52,6 +52,10 @@ def make_db(
   config = config_dict.ConfigDict()
   config.embedding_dim = embedding_dim
   db.insert_metadata('db_config', config)
+  model_config = config_dict.ConfigDict()
+  model_config.embedding_dim = embedding_dim
+  model_config.model_name = 'fake_model'
+  db.insert_metadata('model_config', model_config)
   db.commit()
   return db
