set capture mode thread_local (#2560)
grimoire authored Oct 21, 2024
1 parent d009335 commit a465e60
Showing 1 changed file with 4 additions and 1 deletion.
5 changes: 4 additions & 1 deletion lmdeploy/pytorch/backends/cuda/graph_runner.py
@@ -77,9 +77,12 @@ def capture(self, **kwargs):
         self.model(**padded_kwargs)

         self._graph = torch.cuda.CUDAGraph()
+        # unsafe kernel call in other thread might invalid the capture
+        # so we set thread_safe capture mode here.
         with torch.cuda.graph(self._graph,
                               pool=self.pool,
-                              stream=current_stream):
+                              stream=current_stream,
+                              capture_error_mode='thread_local'):
             output = self.model(**padded_kwargs)

         output_buffers = dict(logits=output)
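
For context, a minimal sketch of the same capture pattern outside lmdeploy. The helper name capture_graph and the constant VALID_CAPTURE_ERROR_MODES are hypothetical and not part of this commit; the sketch assumes a PyTorch release whose torch.cuda.graph accepts capture_error_mode, and it returns None when PyTorch or a CUDA device is unavailable. With the default 'global' mode, a potentially unsafe CUDA call (such as a memory allocation) made from any thread during capture can invalidate the capture; 'thread_local' limits that check to the capturing thread, and 'relaxed' disables it.

```python
# Hypothetical helper, not lmdeploy code: captures fn() into a CUDA graph.
VALID_CAPTURE_ERROR_MODES = ("global", "thread_local", "relaxed")


def capture_graph(fn, capture_error_mode="thread_local"):
    """Return (graph, output) for fn(), or None if CUDA is unavailable."""
    if capture_error_mode not in VALID_CAPTURE_ERROR_MODES:
        raise ValueError(
            f"capture_error_mode must be one of {VALID_CAPTURE_ERROR_MODES}, "
            f"got {capture_error_mode!r}")

    try:
        import torch
    except ImportError:
        return None
    if not torch.cuda.is_available():
        return None

    # Warm up once before capture, as the diff above does with self.model(...).
    fn()
    torch.cuda.synchronize()

    graph = torch.cuda.CUDAGraph()
    # 'thread_local' keeps CUDA calls made by other threads from
    # invalidating this thread's capture.
    with torch.cuda.graph(graph, capture_error_mode=capture_error_mode):
        output = fn()
    return graph, output
```

A caller would pass a closure over preallocated CUDA tensors, for example capture_graph(lambda: static_x * 2), and later call graph.replay() on the returned graph. Note that the accepted string is 'thread_local'; 'thread_safe', the wording used in the code comment, is not a mode name and is rejected by this helper.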