Skip to content

Commit

Permalink
chore: Add _clear_model_cache function to clear model cache
Browse files Browse the repository at this point in the history
  • Loading branch information
fangyinc committed Sep 21, 2023
1 parent 461179e commit 28c2451
Show file tree
Hide file tree
Showing 3 changed files with 19 additions and 9 deletions.
4 changes: 2 additions & 2 deletions pilot/model/cluster/worker/default_worker.py
Original file line number Diff line number Diff line change
Expand Up @@ -8,7 +8,7 @@
from pilot.model.parameter import ModelParameters
from pilot.model.cluster.worker_base import ModelWorker
from pilot.server.chat_adapter import get_llm_chat_adapter, BaseChatAdpter
from pilot.utils.model_utils import _clear_torch_cache
from pilot.utils.model_utils import _clear_model_cache
from pilot.utils.parameter_utils import EnvArgumentParser

logger = logging.getLogger(__name__)
Expand Down Expand Up @@ -87,7 +87,7 @@ def stop(self) -> None:
del self.tokenizer
self.model = None
self.tokenizer = None
_clear_torch_cache(self._model_params.device)
_clear_model_cache(self._model_params.device)

def generate_stream(self, params: Dict) -> Iterator[ModelOutput]:
torch_imported = False
Expand Down
4 changes: 2 additions & 2 deletions pilot/model/cluster/worker/embedding_worker.py
Original file line number Diff line number Diff line change
Expand Up @@ -11,7 +11,7 @@
)
from pilot.model.cluster.worker_base import ModelWorker
from pilot.model.cluster.embedding.loader import EmbeddingLoader
from pilot.utils.model_utils import _clear_torch_cache
from pilot.utils.model_utils import _clear_model_cache
from pilot.utils.parameter_utils import EnvArgumentParser

logger = logging.getLogger(__name__)
Expand Down Expand Up @@ -79,7 +79,7 @@ def stop(self) -> None:
return
del self._embeddings_impl
self._embeddings_impl = None
_clear_torch_cache(self._model_params.device)
_clear_model_cache(self._model_params.device)

def generate_stream(self, params: Dict):
"""Generate stream result, chat scene"""
Expand Down
20 changes: 15 additions & 5 deletions pilot/utils/model_utils.py
Original file line number Diff line number Diff line change
@@ -1,11 +1,21 @@
import logging

logger = logging.getLogger(__name__)

def _clear_torch_cache(device="cuda"):

def _clear_model_cache(device="cuda"):
    """Release cached model resources for *device*.

    Currently this only clears the torch cache (CUDA/MPS) via
    ``_clear_torch_cache``; when torch is not installed it logs and
    returns without raising, so callers can invoke it unconditionally.

    Args:
        device: Device whose cache should be cleared (default ``"cuda"``).
    """
    try:
        # Probe for torch here so a missing install is handled in one place;
        # _clear_torch_cache assumes torch is importable.
        import torch  # noqa: F401

        _clear_torch_cache(device)
    except ImportError:
        # logger.warn is a deprecated alias; use logger.warning instead.
        logger.warning("Torch not installed, skip clear torch cache")
    # TODO clear other (non-torch) caches


def _clear_torch_cache(device="cuda"):
import torch
import gc

gc.collect()
Expand All @@ -16,14 +26,14 @@ def _clear_torch_cache(device="cuda"):

empty_cache()
except Exception as e:
logging.warn(f"Clear mps torch cache error, {str(e)}")
logger.warn(f"Clear mps torch cache error, {str(e)}")
elif torch.has_cuda:
device_count = torch.cuda.device_count()
for device_id in range(device_count):
cuda_device = f"cuda:{device_id}"
logging.info(f"Clear torch cache of device: {cuda_device}")
logger.info(f"Clear torch cache of device: {cuda_device}")
with torch.cuda.device(cuda_device):
torch.cuda.empty_cache()
torch.cuda.ipc_collect()
else:
logging.info("No cuda or mps, not support clear torch cache yet")
logger.info("No cuda or mps, not support clear torch cache yet")

0 comments on commit 28c2451

Please sign in to comment.