get preds off gpu to improve perf (#269)
We were leaving the label prediction tensors on the GPU,
which led to lots of expensive GPU calls to read the
data back. In particular, `tt profile` revealed the following:

![Screen Shot 2023-07-19 at 11 41 06 AM](https://github.com/allenai/mmda/assets/1287054/9f6736d2-51be-4d65-8bfb-fe93cca9eed7)

Scalene's GPU time reporting is generally full of false attribution,
but it was a hint in the right direction in this case.
Based on my interpretation of the code, we were being forced
to access this tensor off the GPU three times for _every_ input
word in the document.
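
The fix boils down to materializing the argmax result as a Python list once, instead of indexing into a (possibly GPU-resident) tensor element by element. A minimal sketch of the pattern, using a random stand-in for the model's logits (the shapes and variable names here are illustrative, not the actual predictor code):

```python
import torch

# Hypothetical stand-in for model output logits: (batch, seq_len, num_labels).
logits = torch.randn(1, 4, 3)

# Before: the argmax result stays a tensor. If it lives on the GPU, every
# later element access (lbls[i]) forces a device-to-host transfer.
label_ids_tensor = torch.argmax(logits, dim=-1)[0]

# After: .tolist() copies the whole tensor into host memory in one call,
# so all subsequent indexing is plain Python list access.
label_ids = torch.argmax(logits, dim=-1).tolist()[0]

assert label_ids == label_ids_tensor.tolist()
assert all(isinstance(x, int) for x in label_ids)
```

The one-time `.tolist()` copy trades a single synchronization for what was previously a sync on every element read.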

After pulling the label preds into system memory:

![Screen Shot 2023-07-19 at 11 44 24 AM](https://github.com/allenai/mmda/assets/1287054/5853924a-824b-4818-aeba-fe121d48c9af)

I confirmed that the remaining high-GPU items reported by the
profiler have nothing to do with the GPU,
and that code is too scary to futz about with anyway.
cmwilhelm authored Jul 19, 2023
1 parent 32f8fbd commit aaf121d
Showing 2 changed files with 2 additions and 2 deletions.
2 changes: 1 addition & 1 deletion pyproject.toml
```diff
@@ -1,6 +1,6 @@
 [project]
 name = 'mmda'
-version = '0.9.5'
+version = '0.9.6'
 description = 'MMDA - multimodal document analysis'
 authors = [
     {name = 'Allen Institute for Artificial Intelligence', email = '[email protected]'},
```
2 changes: 1 addition & 1 deletion src/mmda/predictors/hf_predictors/mention_predictor.py
```diff
@@ -107,7 +107,7 @@ def predict_page(self, page: Annotation, counter: Iterator[int], print_warnings:
             )
             batch.to(self.model.device)
             batch_outputs = self.model(**batch)
-            batch_prediction_label_ids = torch.argmax(batch_outputs.logits, dim=-1)[0]
+            batch_prediction_label_ids = torch.argmax(batch_outputs.logits, dim=-1).tolist()[0]
             prediction_label_ids.append(batch_prediction_label_ids)

         def has_label_id(lbls: List[int], want_label_id: int) -> bool:
```