.impute consumes too much memory #729

Open
Marius1311 opened this issue Jul 5, 2024 · 5 comments

@Marius1311
Collaborator

Marius1311 commented Jul 5, 2024

I'm trying to call problem.impute() on a solved (linear) spatial mapping problem of dimensions n_source=17806 (spatial data) by n_target=13298 (single-cell data) for n_genes=2039. This is just a full-rank Sinkhorn problem with batch_size=None.

Under the hood, this evaluates:

 predictions = [val.to(device=device).pull(gexp_sc, scale_by_marginals=True) for val in self.solutions.values()]

The pull amounts to a matrix multiplication, prediction = P @ X, for a transport matrix P of shape 17806 x 13298 and a single-cell GEX matrix X of shape 13298 x 2039. Thus, the memory bottleneck should be P, which is stored as float32 and should consume around 903 MiB of memory. However, the call to impute fails (see traceback below) because it requests 1.76 TiB of memory. That's because it tries to create an array of shape f32[2039,17806,13298], which is not needed for this operation.
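As a quick sanity check of these numbers, the sizes can be recomputed from the shapes alone. This is plain arithmetic on the dimensions quoted above, assuming 4 bytes per float32 entry; it does not call moscot or ott-jax.

```python
# Problem dimensions quoted in this issue.
n_source, n_target, n_genes = 17806, 13298, 2039
bytes_per_float32 = 4

# Dense transport matrix P of shape (n_source, n_target).
p_bytes = n_source * n_target * bytes_per_float32

# The unwanted intermediate of shape (n_genes, n_source, n_target).
broadcast_bytes = n_genes * p_bytes

print(f"P:            {p_bytes / 2**20:.2f} MiB")
print(f"intermediate: {broadcast_bytes / 2**40:.2f} TiB")
```

The byte count of the intermediate matches the allocation size reported in the traceback below.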

Note that passing a batch size does not help much: if I pass batch_size=500, this still requests an array of shape 2039 x 500 x 13298, which requires over 50 GB of memory. Also, setting a batch size slows down solving the actual OT problem, where batching would not be necessary from a memory point of view.
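To see how the intermediate scales with the batch size, here is a small helper. The function name `batched_buffer_bytes` is hypothetical and only multiplies out the shape n_genes x batch_size x n_target described above, assuming float32 entries.

```python
def batched_buffer_bytes(n_genes, batch_size, n_target, itemsize=4):
    """Bytes needed for an array of shape (n_genes, batch_size, n_target)."""
    return n_genes * batch_size * n_target * itemsize


GIB = 2**30
for batch_size in (500, 100, 10):
    size = batched_buffer_bytes(2039, batch_size, 13298)
    print(f"batch_size={batch_size}: {size / GIB:.1f} GiB")
```

The buffer shrinks only linearly with the batch size, so a very small batch would be needed just to make the pull fit.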

I talked to @michalk8 about this and it's probably a vmap that creates an array of the wrong shape. For now, we could solve this by evaluating the pull batch-wise over small sets of genes. This is inefficient, but would solve the issue for now.
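A minimal sketch of the gene-wise workaround, written against a dense NumPy matrix rather than the moscot or ott-jax API. The function name `pull_in_gene_chunks` and the chunk size are assumptions for illustration, and the marginal scaling done by the real pull is left out.

```python
import numpy as np


def pull_in_gene_chunks(P, X, chunk_size=256):
    """Compute P @ X one block of gene columns at a time."""
    n_source, n_genes = P.shape[0], X.shape[1]
    out = np.empty((n_source, n_genes), dtype=np.result_type(P, X))
    for start in range(0, n_genes, chunk_size):
        stop = min(start + chunk_size, n_genes)
        # Only chunk_size gene columns are processed per step.
        out[:, start:stop] = P @ X[:, start:stop]
    return out


# Small stand-ins for the transport matrix and the expression matrix.
rng = np.random.default_rng(0)
P = rng.random((30, 20), dtype=np.float32)
X = rng.random((20, 7), dtype=np.float32)
chunked = pull_in_gene_chunks(P, X, chunk_size=3)
```

With a lazily evaluated transport matrix, each chunk would correspond to one pull call on a slice of the genes, which bounds the per-call intermediate by the chunk size instead of the full number of genes.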

If the transport matrix fits into CPU memory, then the current best way to go about this is materializing the transport matrix before calling impute:

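import numpy as np

# Materialize each transport matrix as a dense NumPy array before calling impute.
for key, value in lmp.problems.items():
    value.solution.to(device="cpu")
    value.set_solution(np.array(value.solution.transport_matrix), overwrite=True)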

That prevents the memory issue.

Traceback:

2024-07-05 10:45:20.572529: W external/tsl/tsl/framework/bfc_allocator.cc:485] Allocator (GPU_0_bfc) ran out of memory trying to allocate 1.76TiB (rounded to 1931211837440)requested by op 
2024-07-05 10:45:20.572824: W external/tsl/tsl/framework/bfc_allocator.cc:497] *****_______________________________________________________________________________________________
2024-07-05 10:45:20.572951: E external/xla/xla/pjrt/pjrt_stream_executor_client.cc:2732] Execution of replica 0 failed: RESOURCE_EXHAUSTED: Out of memory while trying to allocate 1931211837328 bytes.
BufferAssignment OOM Debugging.
BufferAssignment stats:
             parameter allocation:  929.12MiB
              constant allocation:         0B
        maybe_live_out allocation:    1.76TiB
     preallocated temp allocation:         0B
                 total allocation:    1.76TiB
              total fragmentation:         0B (0.00%)
Peak buffers:
	Buffer 1:
		Size: 1.76TiB
		Operator: op_name="jit(_where)/jit(main)/select_n" source_file="/cluster/project/treutlein/USERS/mlange/github/moscot-fork/src/moscot/backends/ott/output.py" source_line=177
		XLA Label: fusion
		Shape: f32[2039,17806,13298]
		==========================

	Buffer 2:
		Size: 903.26MiB
		Entry Parameter Subshape: f32[17806,13298]
		==========================

	Buffer 3:
		Size: 25.86MiB
		Entry Parameter Subshape: pred[2039,1,13298]
		==========================

	Buffer 4:
		Size: 4B
		Entry Parameter Subshape: f32[]
		==========================
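
The buffer shapes above (a pred[2039,1,13298] mask, the f32[17806,13298] matrix, and a select producing f32[2039,17806,13298]) are consistent with a per-gene mask being broadcast against the full transport matrix. The following NumPy sketch shows how such a broadcast builds a 3D array on toy sizes. It is not the moscot or ott-jax code, and the reduction applied after the `where` is an assumption made only to compare against a plain matrix product.

```python
import numpy as np

# Toy sizes standing in for (2039, 17806, 13298).
n_genes, n_source, n_target = 4, 5, 3
rng = np.random.default_rng(0)

P = rng.random((n_source, n_target), dtype=np.float32)
# Boolean mask with a singleton middle axis, like pred[2039,1,13298].
mask = rng.random((n_genes, 1, n_target)) > 0.5

# Broadcasting (n_genes, 1, n_target) against (n_source, n_target)
# materializes one masked copy of P per gene.
broadcast = np.where(mask, P, np.float32(0.0))

# Assumed reduction over the target axis, then transpose to (n_source, n_genes).
via_broadcast = broadcast.sum(axis=-1).T

# The same result from a 2D matrix product, with no 3D intermediate.
via_matmul = P @ mask[:, 0, :].T.astype(np.float32)
```

In this toy setting the matrix product gives the same values while only ever holding 2D arrays.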
@giovp
Member

giovp commented Jul 5, 2024

Hi @Marius1311, yes, I have observed this multiple times as well and reported it privately to @michalk8.

> Note that passing a batch size does not help much - let's say I'm passing batch_size=500, then this would still request an array of shape 2039 x 500 x 13298, which still requires over 50GB of memory. Also, this slows down solving the actual OT problem, which would not be necessary from a memory point of view.
>
> I talked to @michalk8 about this and it's probably a vmap that creates an array of the wrong shape. For now, we could solve this by evaluating the pull batch-wise over small sets of genes. This is inefficient, but would solve the issue for now.

And yes, I also think that this is due to vmap. In my experience, this also holds for GW problems, and not only for imputation but also for e.g. cell transition: basically anywhere you want to apply the transport matrix.

> For now, we could solve this by evaluating the pull batch-wise over small sets of genes. This is inefficient, but would solve the issue for now.

This is a solution, but it would require a considerable amount of work, as there are various mixin methods that use that operation.

@Marius1311
Collaborator Author

Yes, I agree with you @giovp: batch-wise evaluation isn't really the way to go and can only be a temporary fix. For me personally, materializing the transport matrix before calling .pull is the best solution, as long as the matrix fits into memory.

@giovp
Member

giovp commented Sep 5, 2024

Pinging @michalk8: I think we reported this personally on Slack several other times. Basically any time we do a push/pull of the tmap, an unusual amount of memory seems to be used. AFAIK @Marius1311 mentioned that you were working on some refactoring of the geometries in ott, so I'm wondering if there is any update?

@MUCDK MUCDK mentioned this issue Oct 21, 2024
2 tasks
@MUCDK
Collaborator

MUCDK commented Dec 6, 2024

@selmanozleyen let's please ensure that this is not a problem any more after the update to ott-jax 0.5.0.

@MUCDK MUCDK assigned selmanozleyen and unassigned giovp, michalk8 and MUCDK Dec 6, 2024
@selmanozleyen
Collaborator

@MUCDK I ran a small benchmark with the code I wrote last year (#639 (comment)). It suggests that peak memory stabilizes after some point.

Since this was a very small and quick benchmark on my local system, it's not entirely clear to me that ott-jax has solved this issue on their side. So if this issue remains, we should look at the ott-jax side first again. Let's see what the feedback is after the update.
