Skip to content

Commit

Permalink
Fixed bottleneck in PairwiseReranker
Browse files Browse the repository at this point in the history
Fixed bottleneck in PairwiseReranker
  • Loading branch information
AlekseySh authored Jun 17, 2024
1 parent 4872972 commit 0e3776f
Show file tree
Hide file tree
Showing 7 changed files with 25 additions and 22 deletions.
6 changes: 1 addition & 5 deletions README.md
Original file line number Diff line number Diff line change
Expand Up @@ -335,11 +335,7 @@ for our paper
## [Installation](https://open-metric-learning.readthedocs.io/en/latest/oml/installation.html)

```shell
pip install -U open-metric-learning
```

If you need OML for NLP, install the extra requirements with:
```shell
pip install -U open-metric-learning; # minimum dependencies
pip install -U open-metric-learning[nlp]
```

Expand Down
6 changes: 1 addition & 5 deletions docs/readme/installation.md
Original file line number Diff line number Diff line change
@@ -1,9 +1,5 @@
```shell
pip install -U open-metric-learning
```

If you need OML for NLP, install the extra requirements with:
```shell
pip install -U open-metric-learning; # minimum dependencies
pip install -U open-metric-learning[nlp]
```

Expand Down
7 changes: 5 additions & 2 deletions oml/datasets/images.py
Original file line number Diff line number Diff line change
Expand Up @@ -267,11 +267,14 @@ def __init__(
labels_key=labels_key,
)

self._query_ids = BoolTensor(self.df[IS_QUERY_COLUMN]).nonzero().squeeze()
self._gallery_ids = BoolTensor(self.df[IS_GALLERY_COLUMN]).nonzero().squeeze()

def get_query_ids(self) -> LongTensor:
return BoolTensor(self.df[IS_QUERY_COLUMN]).nonzero().squeeze()
return self._query_ids

def get_gallery_ids(self) -> LongTensor:
return BoolTensor(self.df[IS_GALLERY_COLUMN]).nonzero().squeeze()
return self._gallery_ids


class ImageQueryGalleryDataset(IVisualizableDataset, IQueryGalleryDataset):
Expand Down
7 changes: 5 additions & 2 deletions oml/datasets/texts.py
Original file line number Diff line number Diff line change
Expand Up @@ -189,11 +189,14 @@ def __init__(
index_key=index_key,
)

self._query_ids = BoolTensor(self.df[IS_QUERY_COLUMN]).nonzero().squeeze()
self._gallery_ids = BoolTensor(self.df[IS_GALLERY_COLUMN]).nonzero().squeeze()

def get_query_ids(self) -> LongTensor:
return BoolTensor(self.df[IS_QUERY_COLUMN]).nonzero().squeeze()
return self._query_ids

def get_gallery_ids(self) -> LongTensor:
return BoolTensor(self.df[IS_GALLERY_COLUMN]).nonzero().squeeze()
return self._gallery_ids


class TextQueryGalleryDataset(IVisualizableDataset, IQueryGalleryDataset):
Expand Down
6 changes: 4 additions & 2 deletions oml/retrieval/postprocessors/pairwise.py
Original file line number Diff line number Diff line change
Expand Up @@ -117,9 +117,11 @@ def _process_raw(
# Queries have different number of retrieved items, so we track what pairs are relevant to what queries (bounds)
pairs = []
bounds = [0]
query_ids = dataset.get_query_ids()
gallery_ids = dataset.get_gallery_ids()
for iq, ids_gallery in enumerate(retrieved_ids):
ids_gallery_global = dataset.get_gallery_ids()[ids_gallery][: self.top_n].tolist()
ids_query_global = [dataset.get_query_ids()[iq].item()] * len(ids_gallery_global)
ids_gallery_global = gallery_ids[ids_gallery][: self.top_n].tolist()
ids_query_global = [query_ids[iq].item()] * len(ids_gallery_global)

pairs.extend(list(zip(ids_query_global, ids_gallery_global)))
bounds.append(bounds[-1] + len(ids_gallery_global))
Expand Down
7 changes: 5 additions & 2 deletions tests/test_integrations/test_lightning/test_pipeline.py
Original file line number Diff line number Diff line change
Expand Up @@ -24,6 +24,9 @@ def __init__(self, labels: List[int], im_size: int):
self.im_size = im_size
self.extra_data = dict()

self._query_ids = torch.arange(len(self)).long()
self._gallery_ids = torch.arange(len(self)).long()

def __getitem__(self, item: int) -> Dict[str, Any]:
input_tensors = torch.rand((3, self.im_size, self.im_size))
label = torch.tensor(self.labels[item]).long()
Expand All @@ -40,10 +43,10 @@ def get_labels(self) -> np.ndarray:
return np.array(self.labels)

def get_query_ids(self) -> LongTensor:
return torch.arange(len(self)).long()
return self._query_ids

def get_gallery_ids(self) -> LongTensor:
return torch.arange(len(self)).long()
return self._gallery_ids


class DummyCommonModule(pl.LightningModule):
Expand Down
8 changes: 4 additions & 4 deletions tests/test_integrations/utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -48,8 +48,8 @@ def __init__(
assert len(embeddings) == len(is_query) == len(is_gallery)

self._embeddings = embeddings
self._is_query = is_query
self._is_gallery = is_gallery
self._query_ids = is_query.nonzero().squeeze()
self._gallery_ids = is_gallery.nonzero().squeeze()

self.extra_data = {}
if categories is not None:
Expand Down Expand Up @@ -81,10 +81,10 @@ def __len__(self) -> int:
return len(self._embeddings)

def get_query_ids(self) -> LongTensor:
return self._is_query.nonzero().squeeze()
return self._query_ids

def get_gallery_ids(self) -> LongTensor:
return self._is_gallery.nonzero().squeeze()
return self._gallery_ids


class EmbeddingsQueryGalleryLabeledDataset(EmbeddingsQueryGalleryDataset, IQueryGalleryLabeledDataset):
Expand Down

0 comments on commit 0e3776f

Please sign in to comment.