Skip to content

Commit

Permalink
Add invalid endpoint (#15)
Browse files Browse the repository at this point in the history
  • Loading branch information
kongzii authored Oct 23, 2024
1 parent 0d90594 commit 6ee0103
Show file tree
Hide file tree
Showing 11 changed files with 2,097 additions and 1,681 deletions.
100 changes: 100 additions & 0 deletions labs_api/cache.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,100 @@
from datetime import datetime, timedelta
from typing import Generic, Optional, TypeVar

from prediction_market_agent_tooling.gtypes import HexAddress
from prediction_market_agent_tooling.loggers import logger
from prediction_market_agent_tooling.tools.utils import utcnow
from pydantic import BaseModel
from sqlmodel import Field, Session, SQLModel, create_engine, desc, select

from labs_api.config import Config


class ResponseCacheModel(BaseModel):
    """Base Pydantic model for cacheable API responses.

    Subclasses add their payload fields; `ResponseCache` serializes the whole
    model to JSON for storage and uses these two fields for lookup and expiry.
    """

    # Market the response belongs to; used as the cache lookup key.
    market_id: HexAddress
    # When the response was generated; used for cache-expiry filtering.
    created_at: datetime


class ResponseCacheSQLModel(SQLModel):
    """Abstract SQLModel base for cache tables; subclasses opt in with `table=True`."""

    __table_args__ = {"extend_existing": True}
    # Surrogate primary key, assigned by the database.
    id: Optional[int] = Field(default=None, primary_key=True)
    # Market the cached row belongs to; indexed for lookups in `ResponseCache.find`.
    market_id: str = Field(index=True)
    # Creation time of the cached response; indexed for expiry filter + ordering.
    datetime_: datetime = Field(index=True)
    # Full JSON dump of the corresponding Pydantic response model.
    json_dump: str


# Type variables tying a concrete cache to its Pydantic response model and its
# SQLModel table counterpart (see `ResponseCache` below).
# NOTE: the redundant mid-file `from typing import Generic, TypeVar` was removed;
# `TypeVar` was already imported at the top of the file (PEP 8: imports at top).
ResponseCacheModelVar = TypeVar("ResponseCacheModelVar", bound=ResponseCacheModel)
ResponseCacheSQLModelVar = TypeVar(
    "ResponseCacheSQLModelVar", bound=ResponseCacheSQLModel
)


class ResponseCache(Generic[ResponseCacheModelVar, ResponseCacheSQLModelVar]):
    """Generic cache that persists Pydantic responses as JSON rows in SQL.

    Subclasses pin the two class attributes to a concrete model/table pair.
    """

    RESPONSE_CACHE_MODEL: type[ResponseCacheModelVar]
    RESPONSE_CACHE_SQL_MODEL: type[ResponseCacheSQLModelVar]

    def __init__(
        self,
        cache_expiry_days: int | None,
        sqlalchemy_db_url: str | None = None,
    ):
        self.cache_expiry_days = cache_expiry_days
        # Fall back to the configured secret URL when none is given explicitly.
        db_url = (
            sqlalchemy_db_url
            if sqlalchemy_db_url
            else Config().sqlalchemy_db_url.get_secret_value()
        )
        self.engine = create_engine(db_url)
        self._initialize_db()

    def _initialize_db(self) -> None:
        """Create the cache tables if they don't exist yet."""
        # Referencing the SQL model here makes its import mandatory -- the model
        # class must be imported for `metadata.create_all` to know about it.
        logger.debug(f"tables being added {self.RESPONSE_CACHE_SQL_MODEL}")
        SQLModel.metadata.create_all(self.engine)

    def find(
        self,
        market_id: HexAddress,
    ) -> ResponseCacheModelVar | None:
        """Return the newest non-expired cached response for `market_id`, or None."""
        sql_model = self.RESPONSE_CACHE_SQL_MODEL
        query = select(sql_model).where(sql_model.market_id == market_id)
        if self.cache_expiry_days is not None:
            cutoff = utcnow() - timedelta(days=self.cache_expiry_days)
            query = query.where(sql_model.datetime_ >= cutoff)
        query = query.order_by(desc(sql_model.datetime_))
        with Session(self.engine) as session:
            db_item = session.exec(query).first()
            if db_item is None:
                return None
            try:
                return self.RESPONSE_CACHE_MODEL.model_validate_json(
                    db_item.json_dump
                )
            except ValueError as e:
                # A stale/incompatible row deserializes badly; treat as a miss.
                logger.error(
                    f"Error deserializing {self.RESPONSE_CACHE_MODEL} from cache for {market_id=} and {db_item=}: {e}"
                )
                return None

    def save(
        self,
        item: ResponseCacheModelVar,
    ) -> None:
        """Serialize `item` to JSON and persist it as a new cache row."""
        row = self.RESPONSE_CACHE_SQL_MODEL(
            market_id=item.market_id,
            datetime_=item.created_at,
            json_dump=item.model_dump_json(),
        )
        with Session(self.engine) as session:
            session.add(row)
            session.commit()
6 changes: 4 additions & 2 deletions labs_api/insights.py → labs_api/insights/insights.py
Original file line number Diff line number Diff line change
Expand Up @@ -26,8 +26,10 @@
utcnow,
)

from labs_api.insights_cache import MarketInsightsResponseCache
from labs_api.models import MarketInsightsResponse
from labs_api.insights.insights_cache import (
MarketInsightsResponse,
MarketInsightsResponseCache,
)


# Don't observe the cached version, as it will always return the same thing that's already been observed.
Expand Down
61 changes: 61 additions & 0 deletions labs_api/insights/insights_cache.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,61 @@
import typing as t
from datetime import datetime

from prediction_market_agent_tooling.gtypes import HexAddress
from prediction_market_agent_tooling.tools.tavily_storage.tavily_models import (
TavilyResponse,
TavilyResult,
)
from pydantic import BaseModel

from labs_api.cache import ResponseCache, ResponseCacheModel, ResponseCacheSQLModel


class MarketInsightResult(BaseModel):
    """A single supporting source (url + title) for a market's insights."""

    url: str
    title: str

    @staticmethod
    def from_tavily_result(tavily_result: TavilyResult) -> "MarketInsightResult":
        """Project a full Tavily search result down to just the fields we expose."""
        return MarketInsightResult(
            url=tavily_result.url,
            title=tavily_result.title,
        )


class MarketInsightsResponse(ResponseCacheModel):
    """Cached market insights: an optional news summary plus its sources."""

    summary: str | None
    results: list[MarketInsightResult]

    @property
    def has_insights(self) -> bool:
        """True if there is anything worth showing (a summary or any results)."""
        return bool(self.summary or self.results)

    @staticmethod
    def from_tavily_response(
        market_id: HexAddress,
        created_at: datetime,
        summary: str | None,
        # `X | None` for consistency with the rest of this file (was t.Union[...]).
        tavily_response: TavilyResponse | None,
    ) -> "MarketInsightsResponse":
        """Build a response from an optional Tavily search.

        Each Tavily result is reduced to url/title; a missing response yields
        an empty result list.
        """
        return MarketInsightsResponse(
            market_id=market_id,
            created_at=created_at,
            summary=summary,
            results=(
                [
                    MarketInsightResult.from_tavily_result(result)
                    for result in tavily_response.results
                ]
                if tavily_response
                else []
            ),
        )


class MarketInsightsResponseCacheModel(ResponseCacheSQLModel, table=True):
    """SQL table backing the market-insights response cache."""

    __tablename__ = "market_insights_response_cache"


class MarketInsightsResponseCache(
    ResponseCache[MarketInsightsResponse, MarketInsightsResponseCacheModel]
):
    """Concrete cache wiring `MarketInsightsResponse` to its SQL table."""

    RESPONSE_CACHE_MODEL = MarketInsightsResponse
    RESPONSE_CACHE_SQL_MODEL = MarketInsightsResponseCacheModel
81 changes: 0 additions & 81 deletions labs_api/insights_cache.py

This file was deleted.

51 changes: 51 additions & 0 deletions labs_api/invalid/invalid.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,51 @@
import fastapi
from loguru import logger
from prediction_market_agent_tooling.loggers import logger
from prediction_market_agent_tooling.markets.omen.omen_subgraph_handler import (
HexAddress,
OmenSubgraphHandler,
)
from prediction_market_agent_tooling.tools.is_invalid import is_invalid
from prediction_market_agent_tooling.tools.langfuse_ import observe
from prediction_market_agent_tooling.tools.utils import utcnow

from labs_api.invalid.invalid_cache import (
MarketInvalidResponse,
MarketInvalidResponseCache,
)


# Don't observe the cached version, as it will always return the same thing that's already been observed.
def market_invalid_cached(
    market_id: HexAddress, cache: MarketInvalidResponseCache
) -> MarketInvalidResponse:
    """Returns `market_invalid`, but cached daily."""
    cached = cache.find(market_id)
    if cached is not None:
        return cached

    fresh = market_invalid(market_id)
    # Only cache conclusive answers (invalid is not None); failed checks are
    # not stored, so they get retried on the next request.
    if fresh.has_invalid:
        cache.save(fresh)
    return fresh


@observe()
def market_invalid(market_id: HexAddress) -> MarketInvalidResponse:
    """Returns market invalid for a given market on Omen."""
    try:
        omen_market = OmenSubgraphHandler().get_omen_market_by_market_id(market_id)
    except ValueError:
        # Unknown market id -> surface as a 404 to the API client.
        raise fastapi.HTTPException(
            status_code=404, detail=f"Market with id `{market_id}` not found."
        )

    invalid: bool | None
    try:
        invalid = is_invalid(omen_market.question_title)
    except Exception as e:
        # Best-effort: log the failure and fall back to "unknown" (None).
        logger.error(f"Failed to get is_invalid for market `{market_id}`: {e}")
        invalid = None

    return MarketInvalidResponse(
        market_id=market_id,
        created_at=utcnow(),
        invalid=invalid,
    )
20 changes: 20 additions & 0 deletions labs_api/invalid/invalid_cache.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,20 @@
from labs_api.cache import ResponseCache, ResponseCacheModel, ResponseCacheSQLModel


class MarketInvalidResponse(ResponseCacheModel):
    """Cached answer to "is this market invalid?" for a single market."""

    # True/False once evaluated; None when the invalidity check itself failed.
    invalid: bool | None

    @property
    def has_invalid(self) -> bool:
        """True if the invalidity check produced a conclusive answer."""
        return self.invalid is not None


class MarketInvalidResponseCacheModel(ResponseCacheSQLModel, table=True):
    """SQL table backing the market-invalid response cache."""

    __tablename__ = "market_invalid_response_cache"


class MarketInvalidResponseCache(
    ResponseCache[MarketInvalidResponse, MarketInvalidResponseCacheModel]
):
    """Concrete cache wiring `MarketInvalidResponse` to its SQL table."""

    RESPONSE_CACHE_MODEL = MarketInvalidResponse
    RESPONSE_CACHE_SQL_MODEL = MarketInvalidResponseCacheModel
22 changes: 19 additions & 3 deletions labs_api/main.py
Original file line number Diff line number Diff line change
Expand Up @@ -9,8 +9,16 @@
from prediction_market_agent_tooling.loggers import logger

from labs_api.config import Config
from labs_api.insights import MarketInsightsResponse, market_insights_cached
from labs_api.insights_cache import MarketInsightsResponseCache
from labs_api.insights.insights import (
MarketInsightsResponse,
MarketInsightsResponseCache,
market_insights_cached,
)
from labs_api.invalid.invalid import (
MarketInvalidResponse,
MarketInvalidResponseCache,
market_invalid_cached,
)

HEX_ADDRESS_VALIDATOR = t.Annotated[
HexAddress,
Expand Down Expand Up @@ -41,7 +49,8 @@ async def lifespan(app: fastapi.FastAPI) -> t.AsyncIterator[None]:
allow_methods=["*"],
allow_headers=["*"],
)
market_insights_cache = MarketInsightsResponseCache()
market_insights_cache = MarketInsightsResponseCache(cache_expiry_days=3)
market_invalid_cache = MarketInvalidResponseCache(cache_expiry_days=None)

@app.get("/ping/")
def _ping() -> str:
Expand All @@ -56,6 +65,13 @@ def _market_insights(market_id: HEX_ADDRESS_VALIDATOR) -> MarketInsightsResponse
logger.info(f"Insights for `{market_id}`: {insights.model_dump()}")
return insights

@app.get("/market-invalid/")
def _market_invalid(market_id: HEX_ADDRESS_VALIDATOR) -> MarketInvalidResponse:
    """Returns whether the market might be invalid."""
    invalid = market_invalid_cached(market_id, market_invalid_cache)
    logger.info(f"Invalid for `{market_id}`: {invalid.model_dump()}")
    return invalid

logger.info("API created.")

return app
Expand Down
Loading

0 comments on commit 6ee0103

Please sign in to comment.