
Commit

add TTL cache
nilsmechtel committed Sep 19, 2024
1 parent fb5a619 commit b82c001
Showing 6 changed files with 224 additions and 152 deletions.
275 changes: 168 additions & 107 deletions bioimageio_colab/register_sam_service.py
@@ -1,12 +1,14 @@
import argparse
import io
import os
from functools import partial
from logging import getLogger
from typing import Union

import numpy as np
import requests
import torch
from cachetools import TTLCache
from dotenv import find_dotenv, load_dotenv
from hypha_rpc import connect_to_server
from kaibu_utils import mask_to_features
@@ -16,53 +18,57 @@
if ENV_FILE:
load_dotenv(ENV_FILE)

WORKSPACE_TOKEN = os.environ.get("WORKSPACE_TOKEN")
MODELS = {
"vit_b": "https://dl.fbaipublicfiles.com/segment_anything/sam_vit_b_01ec64.pth",
"vit_b_lm": "https://uk1s3.embassy.ebi.ac.uk/public-datasets/bioimage.io/diplomatic-bug/1/files/vit_b.pt",
"vit_b_em_organelles": "https://uk1s3.embassy.ebi.ac.uk/public-datasets/bioimage.io/noisy-ox/1/files/vit_b.pt",
}
STORAGE = {}

if not WORKSPACE_TOKEN:
raise ValueError("Workspace token is required to connect to the Hypha server.")

logger = getLogger(__name__)
logger.setLevel("INFO")


def _load_model(model_name: str) -> torch.nn.Module:
def _load_model(
model_cache: TTLCache, model_name: str, user_id: str
) -> torch.nn.Module:
if model_name not in MODELS:
raise ValueError(
f"Model {model_name} not found. Available models: {list(MODELS.keys())}"
)

model_url = MODELS[model_name]

# Check cache first
if model_url in STORAGE:
logger.info(f"Loading model {model_name} with ID '{model_url}' from cache...")
return STORAGE[model_url]

# Download model if not in cache
logger.info(f"Loading model {model_name} from {model_url}...")
response = requests.get(model_url)
if response.status_code != 200:
raise RuntimeError(f"Failed to download model from {model_url}")
buffer = io.BytesIO(response.content)

# Load model state
device = "cuda" if torch.cuda.is_available() else "cpu"
ckpt = torch.load(buffer, map_location=device)
model_type = model_name[:5]
sam = sam_model_registry[model_type]()
sam.load_state_dict(ckpt)
sam = model_cache.get(model_name, None)
if sam:
logger.info(
f"User {user_id} - Loading model '{model_name}' from cache (device={sam.device})..."
)
else:
# Download model if not in cache
model_url = MODELS[model_name]
logger.info(
f"User {user_id} - Loading model '{model_name}' from {model_url}..."
)
response = requests.get(model_url)
if response.status_code != 200:
raise RuntimeError(f"Failed to download model from {model_url}")
buffer = io.BytesIO(response.content)

# Load model state
device = "cuda" if torch.cuda.is_available() else "cpu"
ckpt = torch.load(buffer, map_location=device)
model_type = model_name[:5]
sam = sam_model_registry[model_type]()
sam.load_state_dict(ckpt)
logger.info(
f"User {user_id} - Caching model '{model_name}' (device={device})..."
)

# Cache the model
logger.info(f"Caching model {model_name} (device={device}) with ID '{model_url}'...")
STORAGE[model_url] = sam
# Cache the model / renew the TTL
model_cache[model_name] = sam

return sam
# Create a SAM predictor
sam_predictor = SamPredictor(sam)
return sam_predictor


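The TTL caches introduced in this commit come from cachetools. For reference, here is a minimal, self-contained sketch (with illustrative sizes and timeouts, not the service defaults) of the behaviour the caching code above relies on: assigning a key stores it for ttl seconds, re-assigning the same key renews its lifetime, and plain reads do not extend it.

```python
# Sketch of cachetools.TTLCache semantics; values are illustrative only.
import time

from cachetools import TTLCache

cache = TTLCache(maxsize=2, ttl=1.0)   # at most 2 entries, each alive for ~1 s

cache["vit_b"] = "model-object"        # insert -> expires ~1 s from now
assert cache.get("vit_b") == "model-object"   # reads do not renew the TTL

time.sleep(0.6)
cache["vit_b"] = "model-object"        # re-assignment renews the TTL

time.sleep(0.6)
assert "vit_b" in cache                # still alive thanks to the renewal

time.sleep(1.1)
assert cache.get("vit_b") is None      # expired -> get() falls back to None
```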
def _to_image(input_: np.ndarray) -> np.ndarray:
@@ -84,67 +90,59 @@ def _to_image(input_: np.ndarray) -> np.ndarray:
return image


def compute_embedding(model_name: str, image: np.ndarray, context: dict = None) -> bool:
user_id = context["user"].get("id")
if not user_id:
logger.info("User ID not found in context.")
return False
sam = _load_model(model_name)
logger.info(f"User {user_id} - computing embedding of model {model_name}...")
predictor = SamPredictor(sam)
predictor.set_image(_to_image(image))
# Save computed predictor values
logger.info(f"User {user_id} - caching embedding of model {model_name}...")
predictor_dict = {
"model_name": model_name,
"original_size": predictor.original_size,
"input_size": predictor.input_size,
"features": predictor.features, # embedding
"is_image_set": predictor.is_image_set,
}
STORAGE[user_id] = predictor_dict
return True


def reset_embedding(context: dict = None) -> bool:
user_id = context["user"].get("id")
if user_id not in STORAGE:
logger.info(f"User {user_id} not found in storage.")
return False
def _calculate_embedding(
embedding_cache: TTLCache,
sam_predictor: SamPredictor,
model_name: str,
image: np.ndarray,
user_id: str,
) -> SamPredictor:
# Calculate the embedding if not cached
predictor_dict = embedding_cache.get(user_id, {})
if predictor_dict and predictor_dict.get("model_name") == model_name:
logger.info(
f"User {user_id} - Loading image embedding from cache (model: '{model_name}')..."
)
for key, value in predictor_dict.items():
if key != "model_name":
setattr(sam_predictor, key, value)
else:
logger.info(f"User {user_id} - resetting embedding...")
STORAGE[user_id].clear()
return True
logger.info(
f"User {user_id} - Computing image embedding (model: '{model_name}')..."
)
sam_predictor.set_image(_to_image(image))
logger.info(
f"User {user_id} - Caching image embedding (model: '{model_name}')..."
)
predictor_dict = {
"model_name": model_name,
"original_size": sam_predictor.original_size,
"input_size": sam_predictor.input_size,
"features": sam_predictor.features, # embedding
"is_image_set": sam_predictor.is_image_set,
}
# Cache the embedding / renew the TTL
embedding_cache[user_id] = predictor_dict

return sam_predictor

def segment(

def _segment_image(
sam_predictor: SamPredictor,
model_name: str,
point_coordinates: Union[list, np.ndarray],
point_labels: Union[list, np.ndarray],
context: dict = None,
) -> list:
user_id = context["user"].get("id")
if user_id not in STORAGE:
logger.info(f"User {user_id} not found in storage.")
return []

logger.info(
f"User {user_id} - segmenting with model {STORAGE[user_id].get('model_name')}..."
)
# Load the model with the pre-computed embedding
sam = _load_model(STORAGE[user_id].get("model_name"))
predictor = SamPredictor(sam)
for key, value in STORAGE[user_id].items():
if key != "model_name":
setattr(predictor, key, value)
# Run the segmentation
logger.debug(
f"User {user_id} - point coordinates: {point_coordinates}, {point_labels}"
)
user_id: str,
):
if isinstance(point_coordinates, list):
point_coordinates = np.array(point_coordinates, dtype=np.float32)
if isinstance(point_labels, list):
point_labels = np.array(point_labels, dtype=np.float32)
mask, scores, logits = predictor.predict(
logger.debug(
f"User {user_id} - point coordinates: {point_coordinates}, {point_labels}"
)
logger.info(f"User {user_id} - Segmenting image (model: '{model_name}')...")
mask, scores, logits = sam_predictor.predict(
point_coords=point_coordinates[:, ::-1], # SAM has reversed XY conventions
point_labels=point_labels,
multimask_output=False,
@@ -154,27 +152,65 @@ def segment(
return features


def remove_user_id(context: dict = None) -> bool:
def segment(
model_cache: TTLCache,
embedding_cache: TTLCache,
model_name: str,
image: np.ndarray,
point_coordinates: Union[list, np.ndarray],
point_labels: Union[list, np.ndarray],
context: dict = None,
) -> list:
user_id = context["user"].get("id")
if not user_id:
logger.info("User ID not found in context.")
return False

# Load the model
sam_predictor = _load_model(model_cache, model_name, user_id)

# Calculate the embedding
sam_predictor = _calculate_embedding(
embedding_cache, sam_predictor, model_name, image, user_id
)

# Segment the image
features = _segment_image(
sam_predictor, model_name, point_coordinates, point_labels, user_id
)

return features

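Further down, register_service exposes this function through functools.partial so that the two TTL caches are bound once and shared by every remote call, while clients only pass the client-facing arguments. A minimal sketch of that binding pattern, using stand-in names rather than the real service code:

```python
# Sketch of the functools.partial binding used at service registration time.
from functools import partial

from cachetools import TTLCache


def segment_stub(model_cache, embedding_cache, model_name, image, context=None):
    # Stand-in for the real segment() above; just reports the cache sizes.
    return [model_name, len(model_cache), len(embedding_cache)]


model_cache = TTLCache(maxsize=3, ttl=3600)
embedding_cache = TTLCache(maxsize=float("inf"), ttl=600)

# What gets registered: a callable that only takes the client-facing arguments;
# the caches are already bound and therefore shared across all calls.
exposed_segment = partial(segment_stub, model_cache, embedding_cache)
print(exposed_segment("vit_b", image=None))
```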

def clear_cache(embedding_cache: TTLCache, context: dict = None) -> bool:
user_id = context["user"].get("id")
if user_id not in STORAGE:
logger.info(f"User {user_id} not found in storage.")
if user_id not in embedding_cache:
logger.info(f"User {user_id} - User not found in cache.")
return False
else:
logger.info(f"User {user_id} - removing user from storage...")
del STORAGE[user_id]
logger.info(f"User {user_id} - Resetting embedding cache...")
del embedding_cache[user_id]
return True


def hello(context: dict = None) -> str:
return "Welcome to the Interactive Segmentation service!"


async def register_service(args: dict) -> None:
"""
Register the SAM annotation service on the BioImageIO Colab workspace.
"""
workspace_token = args.token or os.environ.get("WORKSPACE_TOKEN")
if not workspace_token:
raise ValueError("Workspace token is required to connect to the Hypha server.")

# Wait until the client ID is available
test_client = await connect_to_server(
{
"server_url": args.server_url,
"workspace": args.workspace_name,
"token": WORKSPACE_TOKEN,
"token": workspace_token,
}
)
colab_client_id = f"{args.workspace_name}/{args.client_id}"
@@ -192,11 +228,15 @@ async def register_service(args: dict) -> None:
"server_url": args.server_url,
"workspace": args.workspace_name,
"client_id": args.client_id,
"name": "Model Server",
"token": WORKSPACE_TOKEN,
"name": "SAM Server",
"token": workspace_token,
}
)

# Initialize caches
model_cache = TTLCache(maxsize=len(MODELS), ttl=args.model_timeout)
embedding_cache = TTLCache(maxsize=np.inf, ttl=args.embedding_timeout)

# Register a new service
service_info = await colab_client.register_service(
{
@@ -208,26 +248,30 @@
"run_in_executor": True,
},
# Exposed functions:
# compute the image embeddings:
# pass the model-name and the image to compute the embeddings on
# calls load_model internally
# returns True if the embeddings were computed successfully
"compute_embedding": compute_embedding,
# run interactive segmentation
# pass the point coordinates and labels
# returns the predicted mask encoded as geo json
"segment": segment,
# reset the embedding for the user
# returns True if the embedding was removed successfully
"reset_embedding": reset_embedding,
# remove the user id from the storage
# returns True if the user was removed successfully
"remove_user_id": remove_user_id, # TODO: add a timeout to remove a user after a certain time
}, {"overwrite": True}
"hello": hello,
# **Run segmentation**
# Params:
# - model name
# - image to compute the embeddings on
# - point coordinates (XY format)
# - point labels
# Returns:
# - a list of XY coordinates of the segmented polygon in the format (1, N, 2)
"segment": partial(segment, model_cache, embedding_cache),
# **Clear the embedding cache**
# Returns:
# - True if the embedding was removed successfully
# - False if the user was not found in the cache
"clear_cache": partial(clear_cache, embedding_cache),
},
{"overwrite": True},
)
sid = service_info["id"]
assert sid == f"{args.workspace_name}/{args.client_id}:{args.service_id}"
logger.info(f"Registered service with ID: {sid}")
logger.info(
f"Test the service here: {args.server_url}/{args.workspace_name}/services/{args.client_id}:{args.service_id}/hello"
)

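Below is a hedged sketch of how a client might call the registered service once it is running. The server URL, workspace name, and the use of get_service are assumptions based on the defaults in this file and typical hypha_rpc usage, not values confirmed by this commit.

```python
# Hypothetical client for the SAM service; IDs and URL are assumptions.
import asyncio

import numpy as np
from hypha_rpc import connect_to_server


async def main():
    server = await connect_to_server({"server_url": "https://hypha.aicell.io"})  # assumed URL
    # Assumed service address: <workspace>/<client_id>:<service_id>
    svc = await server.get_service("bioimageio-colab/kubernetes:sam")

    image = np.random.randint(0, 255, size=(256, 256), dtype=np.uint8)
    # Positional arguments match the exposed segment() after partial binding:
    # model_name, image, point_coordinates (XY), point_labels
    polygon = await svc.segment("vit_b", image, [[128, 128]], [1])
    print(polygon)

    # Drop this user's cached embedding once done
    await svc.clear_cache()


asyncio.run(main())
```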

if __name__ == "__main__":
@@ -246,14 +290,31 @@ async def register_service(args: dict) -> None:
)
parser.add_argument(
"--client_id",
default="model-server",
default="kubernetes",
help="Client ID for registering the service",
)
parser.add_argument(
"--service_id",
default="interactive-segmentation",
default="sam",
help="Service ID for registering the service",
)
parser.add_argument(
"--token",
default=None,
help="Workspace token for connecting to the Hypha server",
)
parser.add_argument(
"--model_timeout",
type=int,
default=9600,  # 2 hours 40 minutes
help="Model cache timeout in seconds",
)
parser.add_argument(
"--embedding_timeout",
type=int,
default=600, # 10 minutes
help="Embedding cache timeout in seconds",
)
args = parser.parse_args()

loop = asyncio.get_event_loop()
