[cache-memory-leak] Fix memory leak in cache server (#416)
- Recreate the multiprocessing pool at a regular cadence to avoid memory leaks (see the sketch below).
- Because the pool was never torn down, the memory it accumulated was never reclaimed.
- Add a log size constraint to the cache server so that oversized logs are not loaded into memory.
- Fix the affected test as well.
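For context, here is a minimal sketch of the pool-recycling pattern the commit adopts, assuming a plain `multiprocessing.Pool`; the `RefreshingPool` wrapper, its method names, and the module-level constant are illustrative only, not classes or names from this repository:

```python
# Hedged sketch: recreate a multiprocessing pool on a fixed cadence so that any
# memory accumulated by the pool is released. RefreshingPool is hypothetical;
# the real server embeds equivalent logic in its scheduler class (see the diff).
import multiprocessing
import os
import time

# Environment variable introduced by this commit; the default is 20 minutes.
REFRESH_DURATION = int(os.environ.get("CACHE_PROCESS_POOL_REFRESH_DURATION", 20 * 60))


class RefreshingPool:
    def __init__(self, max_workers: int):
        self.max_workers = max_workers
        self._start_pool()

    def _start_pool(self):
        # maxtasksperchild recycles individual worker processes; the cadence-based
        # refresh below additionally recycles the pool object itself.
        self.pool = multiprocessing.Pool(processes=self.max_workers, maxtasksperchild=512)
        self._pool_started_on = time.time()

    def refresh_if_due(self, busy: bool):
        # Tear down and recreate the pool once the cadence has elapsed and no work
        # is in flight, so memory held by the old pool is returned to the OS.
        if busy or time.time() - self._pool_started_on < REFRESH_DURATION:
            return
        self.pool.terminate()
        self.pool.join()
        self._start_pool()
```

The committed code additionally force-refreshes the pool after a grace period even if workers are still running, so a stuck worker cannot postpone the refresh indefinitely.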
valayDave authored Feb 19, 2024
1 parent 79e17e3 commit 707c534
Showing 4 changed files with 92 additions and 5 deletions.
@@ -56,7 +56,7 @@ async def read_message(self, line: str):

            if self.logger.isEnabledFor(logging.INFO):
                self.logger.info(
-                    "Pending stream keys: {}".format(list(self.pending_requests))
+                    "Pending stream keys: {}".format(len(list(self.pending_requests)))
                )
        except JSONDecodeError as ex:
            if self.logger.isEnabledFor(logging.INFO):
@@ -73,6 +73,7 @@ async def stop_server(self):
        if self._is_alive:
            self._is_alive = False
            self._proc.terminate()
            self.logger.info("Waiting for cache server to terminate")
            await self._proc.wait()

    async def send_request(self, blob):
57 changes: 57 additions & 0 deletions services/ui_backend_service/data/cache/client/cache_server.py
@@ -9,10 +9,13 @@
from datetime import datetime
from collections import deque
from itertools import chain
import time

from .cache_worker import execute_action
from .cache_async_client import OP_WORKER_CREATE, OP_WORKER_TERMINATE

import sys

import click

from .cache_action import CacheAction, \
@@ -32,6 +35,10 @@ def send_message(op: str, data: dict):
    }), flush=True)


CACHE_PROCESS_POOL_REFRESH_DURATION = int(os.environ.get("CACHE_PROCESS_POOL_REFRESH_DURATION", 20 * 60))
CACHE_PROCESS_POOL_FORCE_REFRESH_DURATION = int(os.environ.get("CACHE_PROCESS_POOL_FORCE_REFRESH_DURATION", 2 * 60))


class CacheServerException(Exception):
    pass

@@ -101,6 +108,7 @@ def __init__(self, request, filestore, pool, callback=None, error_callback=None)
        self.pool = pool
        self.callback = callback
        self.error_callback = error_callback
        self.created_on = time.time()

        try:
            self.tempdir = self.filestore.open_tempdir(
@@ -198,6 +206,7 @@ def __init__(self, filestore, max_workers):
            initializer=self.init_process,
            maxtasksperchild=512,  # Recycle each worker once 512 tasks have been completed
        )
        self._pool_started_on = time.time()

    def init_process(self):
        echo("Init process %s pid: %s" % (multiprocessing.current_process().name, os.getpid()))
@@ -257,16 +266,64 @@ def queued_request(queue):
            send_message(OP_WORKER_TERMINATE, worker._worker_details())
        return None

    def verify_stale_workers(self):
        time_to_pool_refresh = CACHE_PROCESS_POOL_REFRESH_DURATION - (time.time() - self._pool_started_on)
        # active_pids = ",".join([f"{c.name}[{c.pid}]" for c in multiprocessing.active_children()])
        echo(
            "number of workers: %d, number of pending requests: %d; Pool Refresh in : %d" % (
                len(self.workers), len(self.pending_requests), time_to_pool_refresh)
        )

    def cleanup_if_necessary(self):
        time_to_pool_refresh = CACHE_PROCESS_POOL_REFRESH_DURATION - (time.time() - self._pool_started_on)
        if time_to_pool_refresh > 0:
            return
        # Once the refresh is due, recycle the pool right away if the server is idle; if workers are
        # still running, force the refresh after the force-refresh grace period has also elapsed.
        no_workers_are_running = len(self.workers) == 0 and len(self.pending_requests) == 1
        pool_needs_refresh = time_to_pool_refresh <= 0
        pool_force_refresh = time_to_pool_refresh < - CACHE_PROCESS_POOL_FORCE_REFRESH_DURATION
        if pool_force_refresh:
            echo("Force refreshing the pool even though workers may still be running.")
            self.cleanup_workers()
        elif no_workers_are_running and pool_needs_refresh:
            echo("Refreshing the pool as no workers are running and no pending requests are there.")
            self.cleanup_workers()

    def cleanup_workers(self):
        for worker in self.workers:
            worker.echo("Terminating worker")
            worker.terminate()
            self.pending_requests.remove(worker.request['idempotency_token'])
        self.workers = []
        self.cleanup_pool()

    def cleanup_pool(self):
        self.pool.terminate()
        self.pool.join()
        del self.pool
        self.pool = multiprocessing.Pool(
            processes=self.max_workers,
            initializer=self.init_process,
            maxtasksperchild=512,  # Recycle each worker once 512 tasks have been completed
        )
        self._pool_started_on = time.time()

    def loop(self):
        def new_worker_from_request():
            worker = self.schedule()
            if worker:
                self.workers.append(worker)
            return worker
        _counter = time.time()

        while True:
            self.process_incoming_request()
            new_worker_from_request()
            if time.time() - _counter > 30:
                self.verify_stale_workers()
                _counter = time.time()

            self.cleanup_if_necessary()
            time.sleep(0.1)

    def _callback(self, worker, res):
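To make the timing above easier to follow, here is a hedged restatement of the refresh decision in `cleanup_if_necessary` as a standalone predicate; the function name and the boolean `idle` parameter are simplifications for illustration, not code from the commit:

```python
import time

# Defaults from the diff above: the refresh is due after 20 minutes, and is forced
# 2 minutes after that even if work is still in flight.
CACHE_PROCESS_POOL_REFRESH_DURATION = 20 * 60
CACHE_PROCESS_POOL_FORCE_REFRESH_DURATION = 2 * 60


def should_refresh_pool(pool_started_on: float, idle: bool) -> bool:
    # Negative once the regular refresh cadence has elapsed.
    time_to_refresh = CACHE_PROCESS_POOL_REFRESH_DURATION - (time.time() - pool_started_on)
    if time_to_refresh > 0:
        return False
    # Refresh right away if the server is idle; otherwise wait until the refresh
    # is overdue by the force-refresh duration, then refresh regardless.
    return idle or time_to_refresh < -CACHE_PROCESS_POOL_FORCE_REFRESH_DURATION
```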
31 changes: 28 additions & 3 deletions services/ui_backend_service/data/cache/get_log_file_action.py
@@ -163,16 +163,35 @@ def get_log_provider():
    if log_file_policy == 'full':
        return FullLogProvider()
    elif log_file_policy == 'tail':
        # MF_LOG_LOAD_MAX_SIZE: `kilobytes`
        max_log_size = int(os.environ.get('MF_LOG_LOAD_MAX_SIZE', 20 * 1024))
        # In number of characters (UTF-8)
        tail_max_size = int(os.environ.get('MF_LOG_LOAD_TAIL_SIZE', 100 * 1024))
-        return TailLogProvider(tail_max_size=tail_max_size)
+        return TailLogProvider(tail_max_size=tail_max_size, max_log_size_in_kb=max_log_size)
    elif log_file_policy == 'blurb_only':
        return BlurbOnlyLogProvider()
    else:
        raise ValueError("Unknown log value for MF_LOG_LOAD_POLICY (%s). "
                         "Must be 'full', 'tail', or 'blurb_only'" % log_file_policy)


def log_size_exceeded_blurb(task: Task, logtype: str, max_size: int):
    stream_name = 'stderr' if logtype == STDERR else 'stdout'
    blurb = f"""# The size of the log is greater than {int(max_size/1024)}MB which makes it unavailable for viewing on the browser.
Here is a code snippet to get logs using the Metaflow client library:
```
from metaflow import Task, namespace
namespace(None)
task = Task("{task.pathspec}", attempt={task.current_attempt})
{stream_name} = task.{stream_name}
```
# Please visit https://docs.metaflow.org/api/client for detailed documentation."""
    return blurb


def get_log_size(task: Task, logtype: str):
    return task.stderr_size if logtype == STDERR else task.stdout_size

@@ -206,16 +225,22 @@ def get_log_content(self, task: Task, logtype: str):


class TailLogProvider(LogProviderBase):
-    def __init__(self, tail_max_size: int):
+    def __init__(self, tail_max_size: int, max_log_size_in_kb: int):
        super().__init__()
        self._tail_max_size = tail_max_size
        self._max_log_size_in_kb = max_log_size_in_kb

    def get_log_hash(self, task: Task, logtype: str) -> int:
        # We can still use the true log size as a hash - still valid way to detect log growth
        return get_log_size(task, logtype)

    def get_log_content(self, task: Task, logtype: str):

        log_size_in_bytes = get_log_size(task, logtype)
        log_size = log_size_in_bytes / 1024
        if log_size > self._max_log_size_in_kb:
            return [(
                None, log_size_exceeded_blurb(task, logtype, self._max_log_size_in_kb)
            ), ]
        # Note this is inefficient - we will load a 1GB log even if we only want last 100 bytes.
        # Doing this efficiently is a step change in complexity and effort - we can do it when justified in future.
        raw_content = get_log_content(task, logtype)
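As a quick illustration of the new guard in `TailLogProvider.get_log_content`, the check boils down to a kilobyte comparison against `MF_LOG_LOAD_MAX_SIZE`; the helper below is a hedged sketch with a hypothetical name, not code from the commit:

```python
import os

# Same environment variable as in the diff: maximum log size in kilobytes (default 20 MB).
MAX_LOG_SIZE_KB = int(os.environ.get("MF_LOG_LOAD_MAX_SIZE", 20 * 1024))


def exceeds_log_limit(log_size_in_bytes: int, max_log_size_in_kb: int = MAX_LOG_SIZE_KB) -> bool:
    # Logs larger than the limit are served as a short blurb pointing the user at
    # the Metaflow client library instead of being loaded into memory.
    return log_size_in_bytes / 1024 > max_log_size_in_kb


# Example: with the default limit, a 50 MB stdout log would be replaced by the blurb.
assert exceeds_log_limit(50 * 1024 * 1024)
```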
@@ -218,26 +218,30 @@ def test_tail_log_provider(m_get_log_size, m_get_log_content):
    for case in [
        {
            "max_tail_chars": 10000,
            "max_total_size": 200*1024,
            "expected_tail_lines": 1000,
            "expect_to_truncate": False,
        },
        {
            "max_tail_chars": 1000,
            "max_total_size": 200*1024,
            "expected_tail_lines": 333,
            "expect_to_truncate": True,
        },
        {
            "max_tail_chars": 100,
            "max_total_size": 200*1024,
            "expected_tail_lines": 33,
            "expect_to_truncate": True,
        },
        {
            "max_tail_chars": 0,
            "max_total_size": 200*1024,
            "expected_tail_lines": 0,
            "expect_to_truncate": True,
        },
    ]:
-        provider = TailLogProvider(case["max_tail_chars"])
+        provider = TailLogProvider(case["max_tail_chars"], max_log_size_in_kb=case["max_total_size"])
        # Log size should still report full log size (even if only partial content returned)
        assert provider.get_log_hash(mock_task, STDOUT) == mock_log_size
        tail_log_content = provider.get_log_content(mock_task, STDOUT)