Commit ef11f5a
fix(pytorch_poc): memory cal (#606)
tpoisonooo authored Oct 30, 2023
1 parent 2ba8867 commit ef11f5a
Showing 5 changed files with 24 additions and 7 deletions.
.pre-commit-config.yaml: 8 changes (7 additions, 1 deletion)

@@ -3,7 +3,7 @@ repos:
     rev: 4.0.1
     hooks:
       - id: flake8
-        args: ["--exclude=lmdeploy/turbomind/triton_models/*"]
+        args: ["--exclude=lmdeploy/turbomind/triton_models/*", "--max-line-length=79"]
   - repo: https://github.com/PyCQA/isort
     rev: 5.11.5
     hooks:
@@ -12,6 +12,12 @@ repos:
     rev: v0.32.0
     hooks:
       - id: yapf
+        name: yapf
+        description: 'Formatter for Python code'
+        entry: yapf
+        language: python
+        args: ['-i', '--style={based_on_style: pep8, column_limit: 79}']
+
   - repo: https://github.com/pre-commit/pre-commit-hooks
     rev: v4.2.0
     hooks:
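
For reference, the yapf arguments added above pin formatting to pep8 with a 79-column limit. A minimal sketch of exercising the same style through yapf's Python API, purely as an illustration and not part of this commit (assumes yapf is installed):

# Illustration only: the hook runs yapf as a pre-commit entry point; this calls
# the same formatter programmatically. FormatCode returns (formatted, changed).
from yapf.yapflib.yapf_api import FormatCode

messy = 'def add( a,b ):\n    return a+b\n'
formatted, changed = FormatCode(
    messy, style_config='{based_on_style: pep8, column_limit: 79}')
print(formatted)  # def add(a, b):\n    return a + b
print(changed)    # True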
lmdeploy/pytorch_poc/engine/cache_engine.py: 2 changes (2 additions, 0 deletions)

@@ -108,6 +108,7 @@ def allocate_gpu_cache(self):
                 device='cuda',
             )
             gpu_cache.append((key_blocks, value_blocks))
+
         return gpu_cache
 
     def allocate_cpu_cache(self):
@@ -189,6 +190,7 @@ def get_cache_block_size(block_size: int,
         key_cache_block = block_size * num_heads * head_size
         value_cache_block = key_cache_block
         total = num_layers * (key_cache_block + value_cache_block)
+
         dtype_size = _get_dtype_size(model_config.dtype)
         return dtype_size * total
 
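To get a feel for the block size this function computes, here is a small worked example with hypothetical Llama-7B-like dimensions; the numbers are assumptions for illustration, not values taken from this repository:

# Hypothetical dims: 64-token blocks, 32 heads, head_size 128, 32 layers, fp16.
block_size, num_heads, head_size, num_layers = 64, 32, 128, 32
dtype_size = 2  # bytes per element for fp16 (assumed)

key_cache_block = block_size * num_heads * head_size         # 262144 elements
value_cache_block = key_cache_block                          # values mirror keys
total = num_layers * (key_cache_block + value_cache_block)   # 16777216 elements
print(dtype_size * total)  # 33554432 bytes, i.e. 32 MiB per cache block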
lmdeploy/pytorch_poc/engine/engine.py: 5 changes (3 additions, 2 deletions)

@@ -215,8 +215,8 @@ def _update_cache_config(model_config: ModelConfig,
"""
GPU_MEM_PERCENT = 0.7
SWAP_SPACE = 4 * (1 << 30)
reserved_mem = torch.cuda.memory_reserved(gpu_id)
gpu_mem = (get_gpu_memory(gpu_id) - reserved_mem) * GPU_MEM_PERCENT
gpu_mem_physical_free, _ = get_gpu_memory(gpu_id)
gpu_mem = gpu_mem_physical_free * GPU_MEM_PERCENT
cpu_mem = SWAP_SPACE
cache_block_size = CacheEngine.get_cache_block_size(
cache_config.block_size, model_config)
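
The effect of the fix: the cache budget is now 70% of the memory the driver reports as physically free, instead of total device memory minus PyTorch's reservation. A self-contained sketch with hypothetical figures (the final division into blocks is an assumption about what the rest of _update_cache_config does, since those lines are not shown here):

# Hypothetical figures: 7 GiB physically free, 32 MiB per cache block (see the
# worked example under cache_engine.py above).
GPU_MEM_PERCENT = 0.7
gpu_mem_physical_free = 7 * (1 << 30)    # would come from get_gpu_memory(gpu_id)
gpu_mem = gpu_mem_physical_free * GPU_MEM_PERCENT
cache_block_size = 32 * (1 << 20)
print(int(gpu_mem // cache_block_size))  # 156 cache blocks fit in the budget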
@@ -1216,6 +1216,7 @@ def end(self, session_id: int):

     def cancel(self, session_id: int):
         """Stop current streaming inference."""
+
         self._send_req(RequestType.STOP_SESSION, dict(session_id=session_id))
 
     def decode(self, prompt_token_ids: List[List[int]]):
lmdeploy/pytorch_poc/utils.py: 14 changes (10 additions, 4 deletions)

@@ -5,12 +5,17 @@
 from typing import Dict, Sequence
 
 import psutil
-import torch
+import pycuda.driver as drv
 
 
-def get_gpu_memory(gpu: int = 0) -> int:
-    """Returns the total memory of the GPU in bytes."""
-    return torch.cuda.get_device_properties(gpu).total_memory
+def get_gpu_memory(id: int = 0) -> int:
+    """Returns the free and total physical memory of the GPU in bytes."""
+    drv.init()
+    dev = drv.Device(id)
+    cxt = dev.make_context()
+    free, total = drv.mem_get_info()
+    cxt.pop()
+    return free, total
 
 
 def get_cpu_memory() -> int:
@@ -21,6 +26,7 @@ def get_cpu_memory() -> int:
 def bind_sigature(input_names: str, args: Sequence, kwargs: Dict):
     """Bind args and kwargs to given input names."""
     kind = inspect._ParameterKind.POSITIONAL_OR_KEYWORD
+
     sig = Signature([Parameter(name, kind) for name in input_names])
     bind = sig.bind(*args, **kwargs)
     return bind.arguments
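
A quick sanity check of the new helper on a machine with a CUDA device; the import path follows the file location above, and per the diff the function now returns a (free, total) tuple of bytes:

# Assumes `pip install pycuda` and at least one visible CUDA device.
from lmdeploy.pytorch_poc.utils import get_gpu_memory

free, total = get_gpu_memory(0)
print(f'{free / (1 << 30):.1f} GiB free of {total / (1 << 30):.1f} GiB total')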
requirements.txt: 2 changes (2 additions, 0 deletions)

@@ -6,6 +6,8 @@ gradio
 mmengine
 numpy
 pybind11
+
+pycuda
 safetensors
 sentencepiece
 setuptools
