Skip to content

Commit

Permalink
finish with prometheus and fastapi contrib modules
Browse files Browse the repository at this point in the history
  • Loading branch information
Dima Kryukov committed Jan 2, 2024
1 parent 2906888 commit 0b437a7
Show file tree
Hide file tree
Showing 17 changed files with 226 additions and 90 deletions.
3 changes: 3 additions & 0 deletions .coveragerc
Original file line number Diff line number Diff line change
Expand Up @@ -20,6 +20,9 @@ exclude_lines =
@(abc\.)?abstractmethod
@overload

omit =
cashews/_typing.py

[report]
precision = 2
fail_under = 70
Expand Down
89 changes: 89 additions & 0 deletions Readme.md
Original file line number Diff line number Diff line change
Expand Up @@ -70,7 +70,11 @@ More examples [here](https://github.com/Krukov/cashews/tree/master/examples)
- [Cache invalidation on code change](#cache-invalidation-on-code-change)
- [Detect the source of a result](#detect-the-source-of-a-result)
- [Middleware](#middleware)
- [Callbacks](#callbacks)
- [Transactional mode](#transactional)
- [Contrib](#contrib)
- [Fastapi](#fastapi)
- [Prometheus](#prometheus)

### Configuration

Expand Down Expand Up @@ -891,6 +895,23 @@ async def logging_middleware(call, cmd: Command, backend: Backend, *args, **kwar
cache.setup("mem://", middlewares=(logging_middleware, ))
```

#### Callbacks

One of the middlewares preinstalled in a cache instance is `CallbackMiddleware`.
This middleware also adds a new interface to the cache that allows you to register a function to be called whenever a given command is triggered

```python
from cashews import cache, Command


def callback(key, result):
print(f"GET key={key}")

with cache.callback(callback, cmd=Command.GET):
await cache.get("test") # also will print "GET key=test"

```

### Transactional

Applications are more often based on a database with transaction (OLTP) usage. Usually cache supports transactions poorly.
Expand Down Expand Up @@ -963,6 +984,74 @@ async def my_handler():
...
```

### Contrib

This library is framework agnostic, but includes several "batteries" for the most popular tools.

#### Fastapi

You may find a few middlewares useful that can help you control caching in your web application based on fastapi.

1. `CacheEtagMiddleware` - adds an `Etag` header and checks the `If-None-Match` header against the Etag
2. `CacheRequestControlMiddleware` - checks and adds the `Cache-Control` header
3. `CacheDeleteMiddleware` - clears the cache for an endpoint based on the `Clear-Site-Data` header

Example:

```python
from fastapi import FastAPI, Header, Query
from fastapi.responses import StreamingResponse

from cashews import cache
from cashews.contrib.fastapi import (
CacheDeleteMiddleware,
CacheEtagMiddleware,
CacheRequestControlMiddleware,
cache_control_ttl,
)

app = FastAPI()
app.add_middleware(CacheDeleteMiddleware)
app.add_middleware(CacheEtagMiddleware)
app.add_middleware(CacheRequestControlMiddleware)
cache.setup(os.environ.get("CACHE_URI", "redis://"))



@app.get("/")
@cache.failover(ttl="1h")
@cache(ttl=cache_control_ttl(default="4m"), key="simple:{user_agent:hash}", time_condition="1s")
async def simple(user_agent: str = Header("No")):
...


@app.get("/stream")
@cache(ttl="1m", key="stream:{file_path}")
async def stream(file_path: str = Query(__file__)):
return StreamingResponse(_read_file(file_path=file_path))


async def _read_file(file_path: str):
...

```

Cashews can also cache streaming responses.

#### Prometheus

You can easily provide metrics using the Prometheus middleware.

```python
from cashews import cache
from cashews.contrib.prometheus import create_metrics_middleware

metrics_middleware = create_metrics_middleware(with_tag=False)
cache.setup("redis://", middlewares=(metrics_middleware,))

```

## Development

### Setup
Expand Down
4 changes: 2 additions & 2 deletions cashews/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,7 +3,7 @@
from .cache_condition import NOT_NONE, only_exceptions, with_exceptions
from .commands import Command
from .contrib import * # noqa
from .decorators import context_cache_detect, fast_condition, thunder_protection
from .decorators import CacheDetect, context_cache_detect, fast_condition, thunder_protection
from .exceptions import CacheBackendInteractionError, CircuitBreakerOpen, LockedError, RateLimitError
from .formatter import default_formatter, get_template_and_func_for, get_template_for_key
from .helpers import add_prefix, all_keys_lower, memory_limit
Expand All @@ -20,7 +20,7 @@
hit = cache.hit
transaction = cache.transaction
setup = cache.setup
cache_detect: ContextManager = cache.detect
cache_detect: ContextManager[CacheDetect] = cache.detect

circuit_breaker = cache.circuit_breaker
dynamic = cache.dynamic
Expand Down
18 changes: 15 additions & 3 deletions cashews/_typing.py
Original file line number Diff line number Diff line change
Expand Up @@ -52,7 +52,7 @@ def __call__(
...


class Callback(Protocol):
class OnRemoveCallback(Protocol):
async def __call__(
self,
keys: Iterable[Key],
Expand All @@ -61,11 +61,23 @@ async def __call__(
...


class Callback(Protocol):
    """Async hook invoked by ``CallbackMiddleware`` for every executed cache command.

    Receives the command, the affected key, the command's result, and the backend
    that served it.
    """

    async def __call__(self, cmd: Command, key: Key, result: Any, backend: Backend) -> None:
        # `...` (not `pass`) for consistency with the other Protocol stubs in this module.
        ...


class ShortCallback(Protocol):
    """Sync hook registered via ``cache.callback(...)`` for a single command.

    A simplified form of :class:`Callback`: it only receives the key and the result.
    """

    def __call__(self, key: Key, result: Any) -> None:
        # `...` (not `pass`) for consistency with the other Protocol stubs in this module.
        ...


class ICustomEncoder(Protocol):
async def __call__(self, value: Value, backend, key: Key, expire: float | None) -> bytes: # pragma: no cover
async def __call__(
self, value: Value, backend: Backend, key: Key, expire: float | None
) -> bytes: # pragma: no cover
...


class ICustomDecoder(Protocol):
async def __call__(self, value: bytes, backend, key: Key) -> Value: # pragma: no cover
async def __call__(self, value: bytes, backend: Backend, key: Key) -> Value: # pragma: no cover
...
6 changes: 3 additions & 3 deletions cashews/backends/interface.py
Original file line number Diff line number Diff line change
Expand Up @@ -11,7 +11,7 @@
from cashews.exceptions import CacheBackendInteractionError, LockedError

if TYPE_CHECKING: # pragma: no cover
from cashews._typing import Callback, Default, Key, Value
from cashews._typing import Default, Key, OnRemoveCallback, Value

NOT_EXIST = -2
UNLIMITED = -1
Expand Down Expand Up @@ -245,9 +245,9 @@ def enable(self, *cmds: Command) -> None:
class Backend(ControlMixin, _BackendInterface, metaclass=ABCMeta):
def __init__(self, *args, **kwargs) -> None:
super().__init__()
self._on_remove_callbacks: list[Callback] = []
self._on_remove_callbacks: list[OnRemoveCallback] = []

def on_remove_callback(self, callback: Callback) -> None:
def on_remove_callback(self, callback: OnRemoveCallback) -> None:
self._on_remove_callbacks.append(callback)

async def _call_on_remove_callbacks(self, *keys: Key) -> None:
Expand Down
4 changes: 2 additions & 2 deletions cashews/backends/transaction.py
Original file line number Diff line number Diff line change
Expand Up @@ -4,7 +4,7 @@
from uuid import uuid4

from cashews import LockedError
from cashews._typing import Callback, Key, Value
from cashews._typing import Key, OnRemoveCallback, Value
from cashews.backends.interface import NOT_EXIST, UNLIMITED, Backend
from cashews.backends.memory import Memory

Expand Down Expand Up @@ -53,7 +53,7 @@ def _clear_local_storage(self):
self._local_cache = Memory()
self._to_delete = set()

def on_remove_callback(self, callback: Callback):
def on_remove_callback(self, callback: OnRemoveCallback):
self._backend.on_remove_callback(callback)
self._local_cache.on_remove_callback(callback)

Expand Down
6 changes: 3 additions & 3 deletions cashews/contrib/fastapi.py
Original file line number Diff line number Diff line change
Expand Up @@ -4,7 +4,7 @@
from contextlib import nullcontext
from contextvars import ContextVar
from hashlib import blake2s
from typing import ContextManager, Sequence
from typing import Any, ContextManager, Sequence

from starlette.middleware.base import BaseHTTPMiddleware, RequestResponseEndpoint
from starlette.requests import Request
Expand Down Expand Up @@ -121,11 +121,11 @@ async def dispatch(self, request: Request, call_next: RequestResponseEndpoint):

set_key = None

def set_callback(key, result):
def set_callback(key: str, result: Any):
nonlocal set_key
set_key = key

with self._cache.detect as detector, self._cache.callback(Command.SET, set_callback):
with self._cache.detect as detector, self._cache.callback(set_callback, cmd=Command.SET):
response = await call_next(request)
calls = detector.calls_list
if not calls:
Expand Down
4 changes: 2 additions & 2 deletions cashews/contrib/prometheus.py
Original file line number Diff line number Diff line change
Expand Up @@ -12,12 +12,12 @@ def create_metrics_middleware(latency_metric: Optional[Histogram] = None, with_t
_DEFAULT_METRIC = Histogram(
"cashews_operations_latency_seconds",
"Latency of different operations with a cache",
labels=["operation", "backend_class"] if not with_tag else ["operation", "backend_class", "tag"],
labelnames=["operation", "backend_class"] if not with_tag else ["operation", "backend_class", "tag"],
)
_latency_metric = latency_metric or _DEFAULT_METRIC

async def metrics_middleware(call, cmd: Command, backend: Backend, *args, **kwargs):
with _latency_metric as metric:
with _latency_metric.time() as metric:
metric.labels(operation=cmd.value, backend_class=backend.__class__.__name__)
if with_tag and "key" in kwargs:
tags = cache.get_key_tags(kwargs["key"])
Expand Down
26 changes: 14 additions & 12 deletions cashews/wrapper/callback.py
Original file line number Diff line number Diff line change
@@ -1,8 +1,8 @@
import contextlib
import uuid
from typing import TYPE_CHECKING
from typing import TYPE_CHECKING, Any, Iterator

from cashews._typing import AsyncCallable_T
from cashews._typing import AsyncCallable_T, Callback, Key, ShortCallback
from cashews.commands import PATTERN_CMDS, Command
from cashews.key import get_call_values

Expand All @@ -21,18 +21,20 @@ async def __call__(self, call: AsyncCallable_T, cmd: Command, backend: "Backend"
as_key = "pattern" if cmd in PATTERN_CMDS else "key"
call_values = get_call_values(call, args, kwargs)
key = call_values.get(as_key)
if key is None:
return result
for callback in self._callbacks.values():
callback(cmd, key=key, result=result, backend=backend)
await callback(cmd, key=key, result=result, backend=backend)
return result

def add_callback(self, callback, name):
def add_callback(self, callback: Callback, name: str):
self._callbacks[name] = callback

def remove_callback(self, name):
def remove_callback(self, name: str):
del self._callbacks[name]

@contextlib.contextmanager
def callback(self, callback):
def callback(self, callback: Callback) -> Iterator[None]:
name = uuid.uuid4().hex
self.add_callback(callback, name)
try:
Expand All @@ -44,16 +46,16 @@ def callback(self, callback):
class CallbackWrapper(Wrapper):
def __init__(self, name: str = ""):
super().__init__(name)
self._callbacks = CallbackMiddleware()
self.add_middleware(self._callbacks)
self.callbacks = CallbackMiddleware()
self.add_middleware(self.callbacks)

@contextlib.contextmanager
def callback(self, cmd: Command, callback):
def callback(self, callback: ShortCallback, cmd: Command) -> Iterator[None]:
t_cmd = cmd

def _wrapped_callback(cmd, key, result, backend):
async def _wrapped_callback(cmd: Command, key: Key, result: Any, backend: "Backend") -> None:
if cmd == t_cmd:
callback(key, result)
callback(key, result=result)

with self._callbacks.callback(_wrapped_callback):
with self.callbacks.callback(_wrapped_callback):
yield
4 changes: 2 additions & 2 deletions cashews/wrapper/tags.py
Original file line number Diff line number Diff line change
@@ -1,7 +1,7 @@
from functools import lru_cache
from typing import Dict, Iterable, List, Match, Optional, Pattern, Tuple

from cashews._typing import TTL, Callback, Key, KeyOrTemplate, Tag, Tags, Value
from cashews._typing import TTL, Key, KeyOrTemplate, OnRemoveCallback, Tag, Tags, Value
from cashews.backends.interface import Backend
from cashews.exceptions import TagNotRegisteredError
from cashews.formatter import template_to_re_pattern
Expand Down Expand Up @@ -69,7 +69,7 @@ def _add_backend(self, backend: Backend, *args, **kwargs):
super()._add_backend(backend, *args, **kwargs)
backend.on_remove_callback(self._on_remove_cb)

def _on_remove_callback(self) -> Callback:
def _on_remove_callback(self) -> OnRemoveCallback:
async def _callback(keys: Iterable[Key], backend: Backend) -> None:
for tag, _keys in self._group_by_tags(keys).items():
await self.tags_backend.set_remove(self._tags_key_prefix + tag, *_keys)
Expand Down
25 changes: 12 additions & 13 deletions examples/fastapi_app.py
Original file line number Diff line number Diff line change
Expand Up @@ -6,6 +6,7 @@

from fastapi import FastAPI, Header, Query
from fastapi.responses import StreamingResponse
from prometheus_client import make_asgi_app

from cashews import cache
from cashews.contrib.fastapi import (
Expand All @@ -14,38 +15,36 @@
CacheRequestControlMiddleware,
cache_control_ttl,
)
from cashews.contrib.prometheus import create_metrics_middleware

app = FastAPI()
app.add_middleware(CacheDeleteMiddleware)
app.add_middleware(CacheEtagMiddleware)
app.add_middleware(CacheRequestControlMiddleware)
cache.setup(os.environ.get("CACHE_URI", "redis://"))

metrics_middleware = create_metrics_middleware()
cache.setup(os.environ.get("CACHE_URI", "redis://"), middlewares=(metrics_middleware,))
cache.setup("mem://", middlewares=(metrics_middleware,), prefix="srl")
metrics_app = make_asgi_app()
app.mount("/metrics", metrics_app)
KB = 1024


@app.get("/")
@cache.failover(ttl="1h")
@cache.slice_rate_limit(10, "3m")
@cache(ttl=cache_control_ttl(default="4m"), key="simple:{user_agent}", time_condition="1s")
@cache.slice_rate_limit(limit=10, period="3m", key="rate:{user_agent:hash}")
@cache(ttl=cache_control_ttl(default="4m"), key="simple:{user_agent:hash}", time_condition="1s")
async def simple(user_agent: str = Header("No")):
await asyncio.sleep(1.1)
return "".join([random.choice(string.ascii_letters) for _ in range(10)])


@app.get("/stream")
def stream(file_path: str = Query(__file__)):
@cache(ttl="1m", key="stream:{file_path}")
async def stream(file_path: str = Query(__file__)):
return StreamingResponse(_read_file(file_path=file_path))


def size_less(limit: int):
def _condition(chunk, args, kwargs, key):
size = os.path.getsize(kwargs["file_path"])
return size < limit

return _condition


@cache.iterator("2h", key="file:{file_path:hash}", condition=size_less(100 * KB))
async def _read_file(*, file_path, chunk_size=10 * KB):
loop = asyncio.get_running_loop()
with open(file_path, encoding="latin1") as file_obj:
Expand Down
Loading

0 comments on commit 0b437a7

Please sign in to comment.