Skip to content

Commit

Permalink
Redis caching experiment
Browse files Browse the repository at this point in the history
  • Loading branch information
sigurdp committed Nov 22, 2024
1 parent e209c3e commit a0c73df
Showing 1 changed file with 56 additions and 14 deletions.
70 changes: 56 additions & 14 deletions backend_py/primary/primary/routers/surface/router.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,16 +2,7 @@
import logging
from typing import Annotated, List, Optional, Literal

from fastapi import (
APIRouter,
Depends,
HTTPException,
Query,
Response,
Request,
Body,
status,
)
from fastapi import APIRouter, Depends, HTTPException, Query, Response, Request, Body, status
from webviz_pkg.core_utils.perf_metrics import PerfMetrics

from primary.services.sumo_access.case_inspector import CaseInspector
Expand Down Expand Up @@ -51,12 +42,34 @@
)
from .surface_address import decode_surf_addr_str

import redis
from primary import config
from fastapi_cache import FastAPICache
from fastapi import FastAPI
from typing import AsyncIterator
from fastapi_cache.backends.redis import RedisBackend
from contextlib import asynccontextmanager
from fastapi_cache.decorator import cache
import json
from pydantic import TypeAdapter
import hashlib

LOGGER = logging.getLogger(__name__)

router = APIRouter()


# @asynccontextmanager
# async def lifespan(_: FastAPI) -> AsyncIterator[None]:
# #pool = ConnectionPool.from_url(url="redis://redis")
# redis_client = redis.Redis.from_url(config.REDIS_CACHE_URL, decode_responses=False)
# FastAPICache.init(RedisBackend(redis_client), prefix="fastapi-cache")
# yield

redis_client = redis.Redis.from_url(config.REDIS_CACHE_URL, decode_responses=True)
#FastAPICache.init(RedisBackend(redis_client), prefix="")


GENERAL_SURF_ADDR_DOC_STR = """
---
Expand All @@ -83,6 +96,7 @@


@router.get("/realization_surfaces_metadata/")
#@cache(expire=600)
async def get_realization_surfaces_metadata(
response: Response,
authenticated_user: AuthenticatedUser = Depends(AuthHelper.get_authenticated_user),
Expand Down Expand Up @@ -319,6 +333,7 @@ async def post_get_surface_intersection(


@router.post("/sample_surface_in_points")
#@cache(expire=600)
async def post_sample_surface_in_points(
request: Request,
case_uuid: str = Query(description="Sumo case uuid"),
Expand All @@ -329,15 +344,38 @@ async def post_sample_surface_in_points(
sample_points: schemas.PointSetXY = Body(embed=True),
authenticated_user: AuthenticatedUser = Depends(AuthHelper.get_authenticated_user),
) -> List[schemas.SurfaceRealizationSampleValues]:

perf_metrics = PerfMetrics()

ta = TypeAdapter(list[schemas.SurfaceRealizationSampleValues])

# compute hash of points
def compute_hash(numbers: List[int]) -> str:
hash_object = hashlib.sha256()
for number in numbers:
hash_object.update(str(number).encode('utf-8'))
return hash_object.hexdigest()

points_hash = compute_hash(sample_points.x_points + sample_points.y_points)

my_cache_key = f"{authenticated_user.get_user_id()}::{case_uuid}::{ensemble_name}::{surface_name}::{surface_attribute}::{realization_nums}::{points_hash}"
perf_metrics.record_lap("build-key")


cached_data = redis_client.get(my_cache_key)
perf_metrics.record_lap("cache-lookup")

if cached_data:
LOGGER.info(f"!!!!!!!!!!!!!!!!!!!!Returning CACHED result in: {perf_metrics.to_string()}")
return ta.validate_json(cached_data)

sumo_access_token = authenticated_user.get_sumo_access_token()
perf_metrics.record_lap("get-access-token")

async_client = request.app.state.requests_client
perf_metrics.record_lap("get-async-client")

result_arr: List[
RealizationSampleResult
] = await batch_sample_surface_in_points_async(
result_arr: List[RealizationSampleResult] = await batch_sample_surface_in_points_async(
async_client=async_client,
sumo_access_token=sumo_access_token,
case_uuid=case_uuid,
Expand All @@ -358,9 +396,13 @@ async def post_sample_surface_in_points(
sampled_values=res.sampledValues,
)
)
perf_metrics.record_lap("convert to api response")
perf_metrics.record_lap("convert")

redis_client.set(my_cache_key, ta.dump_json(intersections))
perf_metrics.record_lap("write-cache")

LOGGER.info(f"Sampled surface in points in: {perf_metrics.to_string()}")

return intersections


Expand Down

0 comments on commit a0c73df

Please sign in to comment.