Add documentation about how to do serialization in a separate process
bchess committed Apr 30, 2024
1 parent 3b4fb74 commit dc66464
Showing 2 changed files with 190 additions and 0 deletions.
186 changes: 186 additions & 0 deletions README-subprocess-serialization.md
@@ -0,0 +1,186 @@
# Tensorizer serialization via subprocess

If you're using Tensorizer serialization to write checkpoints during training,
you may want to run the serialization concurrently with your training code so
that you can start your next training step as quickly as possible. Because of
the Python GIL, it's best to do this in a separate process so that the
serialization doesn't contend for the GIL that your training code needs.

Keep in mind that this is a way to achieve _concurrency_, not instant
snapshotting. The tensors you are checkpointing still need to be kept in memory,
unmodified, for the duration of the serialization process. (Though you may
choose to copy them out of CUDA memory into CPU memory. These tradeoffs are
discussed below.)

Also refer to the [PyTorch Multiprocessing best
practices](https://pytorch.org/docs/stable/notes/multiprocessing.html) for more
details about using PyTorch across processes.


## Warning about fork() and threads
Be aware that Python `os.fork()` is often not a viable option: it is known to cause deadlocks when other threads are running. Python 3.12 and above
will [issue a deprecation warning](https://github.com/python/cpython/pull/100229) if you attempt this. Third-party packages that rely on sockets or file descriptors may also misbehave when a process unexpectedly forks.

Spawning a fresh subprocess is generally safer, but the new process does not inherently share memory with the calling process.
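
The examples below use `mp.set_start_method('spawn')`; if you prefer not to change the global start method, `multiprocessing.get_context('spawn')` gives the same behavior scoped to the processes you create with it, and `torch.multiprocessing` re-exports the standard `multiprocessing` API. A minimal sketch (the `work` function here is a hypothetical stand-in for a serialization task):

```python
import torch.multiprocessing as mp

def work(name: str):
    # Runs in a freshly spawned interpreter, so no threads or locks
    # are inherited from the parent process
    print(f"hello from the {name} subprocess")

if __name__ == '__main__':
    # A spawn context avoids fork-related deadlocks without changing the
    # global start method used elsewhere in the program
    ctx = mp.get_context('spawn')
    p = ctx.Process(target=work, args=("serialization",))
    p.start()
    p.join()
```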

## If starting from CUDA
Presuming your tensors are in CUDA memory, there are a couple of different options.

### Option 1: Communicate the CUDA tensors directly
CUDA tensors can be "shared" with a subprocess very efficiently, since only a handle to device memory is communicated.

Send the CUDA tensors to a subprocess that performs the serialization, either as `Process` arguments (as below) or over a `multiprocessing.Queue`. Ensure that the CUDA tensors remain in device memory until the serialization subprocess finishes.

```python
import torch
from tensorizer import TensorSerializer
from transformers import AutoModelForCausalLM
import torch.multiprocessing as mp

def do_serialize(uri: str, model: torch.nn.Module):
    serializer = TensorSerializer(uri)
    serializer.write_module(model)
    serializer.close()

def my_gpu_model() -> torch.nn.Module:
    model_ref = "EleutherAI/gpt-j-6B"
    model = AutoModelForCausalLM.from_pretrained(
        model_ref,
        revision="float16",
        torch_dtype=torch.float16,
        low_cpu_mem_usage=True,
    )
    model.to('cuda')
    return model

def main():
    dest = "gpt-j-6B.tensors"
    model = my_gpu_model()

    mp.set_start_method('spawn')
    p = mp.Process(target=do_serialize, args=(dest, model))
    p.start()

    # The main process is now free to do other work, but `model` must remain
    # in CUDA memory until the `p` subprocess finishes

    p.join()


if __name__ == '__main__':
    main()
```

### Option 2: Snapshot CUDA tensors to CPU memory in subprocess before serializing

Once the tensors have been copied to CPU memory, they no longer need to occupy CUDA memory,
but they will need to stay in CPU memory until they are fully serialized.

Do this by calling `model.to('cpu')` in the subprocess, before serialization begins.

If you like, you can also use an IPC object to signal the host process when the
snapshot has finished, so you know when the CUDA memory can be released. The
code below uses a `Queue`.

```python
import torch
from tensorizer import TensorSerializer
from transformers import AutoModelForCausalLM
import torch.multiprocessing as mp

def do_serialize(uri: str, model: torch.nn.Module, snapshot_done: mp.Queue):
    model = model.to('cpu')  # Snapshot now
    snapshot_done.put(True)

    serializer = TensorSerializer(uri)
    serializer.write_module(model)
    serializer.close()

def my_gpu_model() -> torch.nn.Module:
    model_ref = "EleutherAI/gpt-j-6B"
    model = AutoModelForCausalLM.from_pretrained(
        model_ref,
        revision="float16",
        torch_dtype=torch.float16,
        low_cpu_mem_usage=True,
    )
    model.to('cuda')
    return model

def main():
    dest = "gpt-j-6B.tensors"
    model = my_gpu_model()

    mp.set_start_method('spawn')
    snapshot_done = mp.Queue()
    p = mp.Process(target=do_serialize, args=(dest, model, snapshot_done))
    p.start()

    # The main process is now free to do other work,
    # but `model` must remain in CUDA memory

    snapshot_done.get()
    # The subprocess has copied the model into CPU memory;
    # the CUDA-based model can now be released
    del model

    # ... do other stuff ...

    if not p.is_alive():
        print('Serialization finished.')

    p.join()


if __name__ == '__main__':
    main()
```

## If starting from CPU memory

Tensors in CPU memory need to be moved into shared memory to be communicated to a subprocess. PyTorch `multiprocessing` will do this transparently, but be aware
that a memcpy occurs. You'll also need additional "surge" CPU memory while each tensor is being copied.

Depending on how you construct your CPU tensors, you may be able to call `Tensor.share_memory_()` (or `Module.share_memory()` for a whole model) preemptively,
saving a memcpy when passing them to the subprocess.
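
For instance, a minimal sketch of pre-sharing a tensor (the shape is an arbitrary placeholder; for a whole model, `Module.share_memory()` applies the same operation to every parameter and buffer):

```python
import torch

# Allocate a CPU tensor and move its storage into shared memory up front,
# so torch.multiprocessing can hand it to a subprocess without another copy
weights = torch.empty(4096, 4096)
weights.share_memory_()
assert weights.is_shared()

# For a whole model, the equivalent is:
# model.share_memory()
```

The full example below simply passes the module to the subprocess and lets `torch.multiprocessing` move it into shared memory when the process starts.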

```python
import torch
from tensorizer import TensorSerializer
from transformers import AutoModelForCausalLM
import torch.multiprocessing as mp

def do_serialize(uri: str, model: torch.nn.Module):
    serializer = TensorSerializer(uri)
    serializer.write_module(model)
    serializer.close()

def my_cpu_model() -> torch.nn.Module:
    model_ref = "EleutherAI/gpt-j-6B"
    model = AutoModelForCausalLM.from_pretrained(
        model_ref,
        revision="float16",
        torch_dtype=torch.float16,
        low_cpu_mem_usage=True,
    )
    return model

def main():
    dest = "gpt-j-6B.tensors"
    model = my_cpu_model()

    mp.set_start_method('spawn')

    # Passing `model` as an argument moves its tensors into shared memory
    # (the equivalent of model.share_memory()) when the subprocess starts
    p = mp.Process(target=do_serialize, args=(dest, model))

    p.start()

    # The main process is now free to do other work, but `model` must remain
    # in CPU memory until the `p` subprocess finishes

    p.join()


if __name__ == '__main__':
    main()
```
4 changes: 4 additions & 0 deletions README.md
@@ -552,3 +552,7 @@ python -m pip install -e .
python -m pip install -r tests/requirements.txt
python -m unittest discover tests/ --verbose
```

## Serialization in a subprocess
You may want to do serialization in a separate process so that your main process can continue executing without getting bogged down by GIL contention.
See [README-subprocess-serialization.md](README-subprocess-serialization.md) for more details.
