Idea for Discussion: Device Memory Management Primitives #780
Comments
How much of this is an implementation issue (i.e. implementations can be cleverer than they currently are about memory management when multiple large graphs are active) vs. something that needs to be exposed to developers via the API?
I think model swapping between GPU and main memory is feasible in a clever implementation (an LRU cache of some sort). I'm not sure how much overhead that would add, or whether it would make it harder for web applications to get predictable performance characteristics (what if the LRU cache isn't clever enough to adapt to the workload?). I think "stream from disk" / mmap most likely requires some API change (I don't think ArrayBuffers are mmap-ed).
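To make the LRU idea concrete, here is a minimal sketch of an implementation-side cache that keeps model weights resident on the device up to a memory budget and evicts the least-recently-used model when a new one does not fit. The class name, the `acquire()` method, and the unit-less sizes are hypothetical; eviction simply stands in for copying weights back to main memory.

```python
from collections import OrderedDict


class DeviceModelCache:
    """Hypothetical LRU cache of model weights resident in GPU memory.

    Sizes are in arbitrary units (e.g. GB). When a model does not fit,
    the least-recently-used models are evicted first; eviction stands in
    for copying their weights back to main memory.
    """

    def __init__(self, capacity: float):
        self.capacity = capacity
        self.resident: "OrderedDict[str, float]" = OrderedDict()
        self.evictions: list = []  # names of evicted models, in order
        self.loads = 0             # number of host-to-device uploads

    def used(self) -> float:
        return sum(self.resident.values())

    def acquire(self, name: str, size: float) -> None:
        """Make `name` resident on the device, evicting LRU models as needed."""
        if size > self.capacity:
            raise MemoryError(f"{name} ({size}) exceeds device capacity {self.capacity}")
        if name in self.resident:
            # Cache hit: no upload, just mark it as most recently used.
            self.resident.move_to_end(name)
            return
        # Evict least-recently-used models until the new one fits.
        while self.used() + size > self.capacity:
            evicted, _ = self.resident.popitem(last=False)
            self.evictions.append(evicted)
        self.resident[name] = size
        self.loads += 1
```

The open question from the comment is visible here: whether a model hits or misses depends entirely on the access pattern, so a web application cannot easily predict when an upload (and its latency) will happen.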
Once a graph is built, it isn't backed by an ArrayBuffer; it is opaquely held behind the … Similarly, we've been working on changes to the implementation of …
So reading into an ArrayBuffer is still required? Some model files can be >40GB in size, which won't fit in main memory, so the graph-building stage will fail (because building requires an ArrayBuffer in memory).
With the changes we've made, the ArrayBuffers passed to constant() do not need to be held in memory until build() is called. They could be written to disk incrementally; they only need to exist in memory for the duration of the constant() call. While a model file can be >4GB in size, an individual constant is typically much smaller because it only contains the weights for a single layer of the model.
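The behaviour described above can be sketched as a builder that spills each constant to a file as soon as it is passed in, so peak buffered data is bounded by the largest single constant rather than by the whole model. This is a hypothetical illustration (the class and method names are made up and only loosely mirror constant()/build()); it is not how any WebNN implementation is actually structured.

```python
import os
import tempfile


class SpillingGraphBuilder:
    """Hypothetical builder that writes each constant to disk as it arrives.

    A constant's bytes only need to be in memory during constant();
    they are written to a spill file immediately, so the largest buffer
    ever handled is one constant, not the whole model.
    """

    def __init__(self):
        fd, self.path = tempfile.mkstemp(suffix=".weights")
        self._file = os.fdopen(fd, "wb")
        self.offsets = {}            # name -> (offset, length) in the spill file
        self.peak_buffer_bytes = 0   # largest single buffer passed to constant()

    def constant(self, name: str, data: bytes) -> None:
        self.peak_buffer_bytes = max(self.peak_buffer_bytes, len(data))
        offset = self._file.tell()
        self._file.write(data)
        self.offsets[name] = (offset, len(data))

    def build(self) -> dict:
        # Close (and flush) the spill file; a real backend could now upload
        # weights from disk to the device one layer at a time.
        self._file.close()
        return dict(self.offsets)

    def read_constant(self, name: str) -> bytes:
        # Intended to be called after build(), once the file is flushed.
        offset, length = self.offsets[name]
        with open(self.path, "rb") as f:
            f.seek(offset)
            return f.read(length)


# Usage: four hypothetical 1000-byte "layers".
builder = SpillingGraphBuilder()
for i in range(4):
    builder.constant(f"layer{i}.weight", bytes([i]) * 1000)
layout = builder.build()
```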
The idea is perhaps future-looking, but I'd like to bring it up for discussion.
Motivations
Real World Examples
Case 1: Text to Image
Text-to-image use cases generally involve multiple models, organized as a three-stage pipeline (consisting of at least three models).
Take FLUX generating a 1024x1024 image as an example:
In the ideal case, all three stages fit in GPU memory (totaling >32GB). This exceeds the memory of every consumer GPU (except Apple M-series chips with large unified memory).
The only practical way to run FLUX is to "load, compute, unload" each model on the GPU in sequence, at the cost of reinitializing each stage for every text-to-image inference.
This reduces the required GPU memory from sum(required_memory_per_stage) to max(required_memory_per_stage), and requires main memory to fit sum(size_of_model_weight).
Note:
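The memory arithmetic of the "load, compute, unload" approach can be checked with a few lines of code. The per-stage numbers below are hypothetical, chosen only to illustrate the formulas; they are not measured FLUX figures.

```python
def gpu_memory_all_resident(required_memory_per_stage):
    """All stages stay on the GPU at the same time."""
    return sum(required_memory_per_stage)


def gpu_memory_load_compute_unload(required_memory_per_stage):
    """Only one stage is on the GPU at a time."""
    return max(required_memory_per_stage)


# Hypothetical numbers in GB, for illustration only.
required_memory_per_stage = [10.0, 24.0, 2.0]  # weights + activations per stage
size_of_model_weight = [9.0, 22.0, 0.5]        # weights only, parked in main memory

print(gpu_memory_all_resident(required_memory_per_stage))         # 36.0
print(gpu_memory_load_compute_unload(required_memory_per_stage))  # 24.0
print(sum(size_of_model_weight))                                  # 31.5
```

The GPU requirement drops to the largest single stage, but main memory must still hold every stage's weights, and each inference pays for the uploads.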
Case 2: Mixture of Experts (MoE) model
MoE models are self-explanatory. They use multiple small models to produce one inference result (e.g. the next token in an LLM).
Only one small model (say, 1/8 the size of the entire model) needs to reside in GPU memory at a time. Each small model computes its result in sequence, then the results are merged into a single output token.
High-level pseudocode:
If the GPU has enough memory, all of the small models can reside in memory (load and unload become no-ops). If not, the small models are repeatedly loaded onto and unloaded from the GPU (usually to/from main memory).
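A minimal sketch of the load/compute/unload loop described above, assuming hypothetical `load()`/`unload()` helpers that stand in for host-to-device copies and a simple average as the merge step. Following the description, every expert runs in sequence; with enough room, loads become no-ops after the first token.

```python
from typing import Callable, Dict, List


class ExpertRunner:
    """Hypothetical runner keeping at most `max_resident` experts on the GPU."""

    def __init__(self, experts: Dict[str, Callable[[float], float]], max_resident: int):
        self.experts = experts
        self.max_resident = max_resident
        self.resident: List[str] = []  # experts currently "on the GPU"
        self.load_count = 0            # host-to-device uploads performed

    def load(self, name: str) -> None:
        if name in self.resident:
            return  # already resident: load is a no-op
        if len(self.resident) >= self.max_resident:
            self.unload(self.resident[0])  # make room by dropping the oldest
        self.resident.append(name)
        self.load_count += 1

    def unload(self, name: str) -> None:
        self.resident.remove(name)

    def predict_token(self, x: float) -> float:
        # Each expert computes its result in sequence; the results are then
        # merged (here: averaged, as a placeholder) into a single output.
        outputs = []
        for name, fn in self.experts.items():
            self.load(name)
            outputs.append(fn(x))
        return sum(outputs) / len(outputs)


# Usage: eight toy experts; expert i adds i to its input.
experts = {f"expert{i}": (lambda x, i=i: x + i) for i in range(8)}
tight = ExpertRunner(experts, max_resident=1)  # GPU fits one expert
roomy = ExpertRunner(experts, max_resident=8)  # GPU fits all experts
for _ in range(3):
    tight.predict_token(1.0)
    roomy.predict_token(1.0)
```

With room for only one expert, every token re-uploads all eight (24 loads for 3 tokens); with room for all of them, only the first token pays (8 loads).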
Some LLMs adopt an architecture where the number of activated params (model weights that have to reside in GPU memory) is much smaller than the number of total params (total size of the model weights). I believe they function similarly to MoE at inference time from a memory-usage standpoint.
Examples:
Case 3: Model Streaming
I observed this technique when experimenting with Ollama.
If the model weights are too big to fit into main memory, they are streamed from disk during inference (e.g. the model is read from disk N times for N predicted tokens).
It's slow (bottlenecked by disk throughput), but it does make inference on large models possible.
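A minimal sketch of why streaming is disk-bound, assuming a hypothetical weight file laid out as fixed-size layers: every predicted token re-reads the whole file, holding only one layer in memory at a time, so total bytes read grow as N times the model size. This is an illustration of the idea, not how Ollama implements it.

```python
import os
import tempfile


def stream_inference(weight_path: str, layer_bytes: int, num_tokens: int) -> int:
    """Re-read every layer from disk for every token; return total bytes read."""
    total_read = 0
    for _ in range(num_tokens):
        with open(weight_path, "rb") as f:
            while True:
                layer = f.read(layer_bytes)  # only one layer held in memory
                if not layer:
                    break
                total_read += len(layer)
                # A real runtime would run this layer's computation here,
                # then drop `layer` before reading the next one.
    return total_read


# Usage: a fake 4 KiB "model" made of four 1 KiB layers, streamed for 5 tokens.
with tempfile.NamedTemporaryFile(delete=False) as tmp:
    tmp.write(b"\x00" * 4096)
total = stream_inference(tmp.name, layer_bytes=1024, num_tokens=5)
print(total)  # 20480
os.remove(tmp.name)
```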
Current WebNN API
Cases 1 and 2 are feasible but not efficient. Model pipelining and partitioning involve destroying and rebuilding the compute graph.
For every inference (e.g. generating one image from a text prompt, or predicting one token with an MoE model), we need to:
Case 3 is infeasible because the entire model weights need to be copied to the WebNN service process before we can build a graph. We can fall back to model partitioning and convert the problem into Case 2.
API / Spec Implications
The topic for discussion :)
Two primitives?
model.to(device)
mmap equivalent?
Related spec / APIs I can think of:
FileSystemFileHandle?
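For reference, this is what the OS-level mmap primitive looks like, using Python's standard mmap module on a hypothetical weight file of four 1 KiB layers. Mapping the file does not read it; slicing one layer faults in only the pages backing that range, and clean file-backed pages can later be dropped by the OS under memory pressure. This only illustrates the underlying OS mechanism, not any existing or proposed WebNN API.

```python
import mmap
import os
import tempfile


def read_layer(mapped: mmap.mmap, offset: int, length: int) -> bytes:
    # Copying this slice touches only the pages that back [offset, offset+length).
    return mapped[offset:offset + length]


# Usage: write four hypothetical 1 KiB layers, then map the file read-only.
with tempfile.NamedTemporaryFile(delete=False) as tmp:
    for i in range(4):
        tmp.write(bytes([i]) * 1024)

with open(tmp.name, "rb") as f:
    with mmap.mmap(f.fileno(), 0, access=mmap.ACCESS_READ) as mapped:
        file_size = len(mapped)
        layer2 = read_layer(mapped, offset=2 * 1024, length=1024)
        print(file_size, layer2[:1])  # 4096 b'\x02'
os.remove(tmp.name)
```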