Skip to content

Commit

Permalink
List of tensors gets concatenated once at the end
Browse files Browse the repository at this point in the history
  • Loading branch information
sjfleming committed Aug 25, 2023
1 parent 16c7f72 commit 86a0c90
Showing 1 changed file with 42 additions and 27 deletions.
69 changes: 42 additions & 27 deletions cellbender/remove_background/posterior.py
Original file line number Diff line number Diff line change
Expand Up @@ -451,7 +451,7 @@ def _get_cell_noise_count_posterior_coo(
f'accurate for your dataset.')
raise RuntimeError('Zero cells found!')

dataloader_index_to_analyzed_bc_index = np.where(cell_logic)[0]
dataloader_index_to_analyzed_bc_index = torch.where(torch.tensor(cell_logic))[0]
cell_data_loader = DataLoader(
count_matrix[cell_logic],
empty_drop_dataset=None,
Expand All @@ -468,6 +468,11 @@ def _get_cell_noise_count_posterior_coo(
log_probs = []
ind = 0
n_minibatches = len(cell_data_loader)
analyzed_gene_inds = torch.tensor(self.analyzed_gene_inds.copy())
if analyzed_bcs_only:
barcode_inds = torch.tensor(self.dataset_obj.analyzed_barcode_inds.copy())
else:
barcode_inds = torch.tensor(self.barcode_inds.copy())

logger.info('Computing posterior noise count probabilities in mini-batches.')

Expand Down Expand Up @@ -505,7 +510,7 @@ def _get_cell_noise_count_posterior_coo(
)

# Get the original gene index from gene index in the trimmed dataset.
genes_i = self.analyzed_gene_inds[genes_i_analyzed]
genes_i = analyzed_gene_inds[genes_i_analyzed.cpu()]

# Barcode index in the dataloader.
bcs_i = bcs_i_chunk + ind
Expand All @@ -514,37 +519,47 @@ def _get_cell_noise_count_posterior_coo(
bcs_i = dataloader_index_to_analyzed_bc_index[bcs_i]

# Translate chunk barcode inds to overall inds.
if analyzed_bcs_only:
bcs_i = self.dataset_obj.analyzed_barcode_inds[bcs_i]
else:
bcs_i = self.barcode_inds[bcs_i]
bcs_i = barcode_inds[bcs_i.cpu()]

# Add sparse matrix values to lists.
try:
bcs.extend(bcs_i.tolist())
genes.extend(genes_i.tolist())
c.extend(c_i.tolist())
log_probs.extend(log_prob_i.tolist())
c_offset.extend(noise_count_offset_NG[bcs_i_chunk, genes_i_analyzed]
.detach().cpu().numpy())
except TypeError as e:
# edge case of a single value
bcs.append(bcs_i)
genes.append(genes_i)
c.append(c_i)
log_probs.append(log_prob_i)
c_offset.append(noise_count_offset_NG[bcs_i_chunk, genes_i_analyzed]
.detach().cpu().numpy())
bcs.append(bcs_i.detach().cpu())
genes.append(genes_i.detach().cpu())
c.append(c_i.detach().cpu())
log_probs.append(log_prob_i.detach().cpu())
c_offset.append(noise_count_offset_NG[bcs_i_chunk, genes_i_analyzed].detach().cpu())

# try:
# bcs.extend(bcs_i.tolist())
# genes.extend(genes_i.tolist())
# c.extend(c_i.tolist())
# log_probs.extend(log_prob_i.tolist())
# c_offset.extend(noise_count_offset_NG[bcs_i_chunk, genes_i_analyzed]
# .detach().cpu().numpy())
# except TypeError as e:
# # edge case of a single value
# bcs.append(bcs_i)
# genes.append(genes_i)
# c.append(c_i)
# log_probs.append(log_prob_i)
# c_offset.append(noise_count_offset_NG[bcs_i_chunk, genes_i_analyzed]
# .detach().cpu().numpy())

# Increment barcode index counter.
ind += data.shape[0] # Same as data_loader.batch_size

# Convert the lists to numpy arrays.
log_probs = np.array(log_probs, dtype=float)
c = np.array(c, dtype=np.uint32)
barcodes = np.array(bcs, dtype=np.uint64) # uint32 is too small!
genes = np.array(genes, dtype=np.uint64) # use same as above for IndexConverter
noise_count_offsets = np.array(c_offset, dtype=np.uint32)
# # Convert the lists to numpy arrays.
# log_probs = np.array(log_probs, dtype=float)
# c = np.array(c, dtype=np.uint32)
# barcodes = np.array(bcs, dtype=np.uint64) # uint32 is too small!
# genes = np.array(genes, dtype=np.uint64) # use same as above for IndexConverter
# noise_count_offsets = np.array(c_offset, dtype=np.uint32)

# Concatenate lists.
log_probs = torch.cat(log_probs)
c = torch.cat(c)
barcodes = torch.cat(bcs)
genes = torch.cat(genes)
noise_count_offsets = torch.cat(c_offset)

# Translate (barcode, gene) inds to 'm' format index.
m = self.index_converter.get_m_indices(cell_inds=barcodes, gene_inds=genes)
Expand Down

0 comments on commit 86a0c90

Please sign in to comment.