Skip to content

Commit

Permalink
Add client fetch to auth dependency and create new dependency for asy…
Browse files Browse the repository at this point in the history
…nc stuff
  • Loading branch information
AndrewLester committed Apr 28, 2023
1 parent e4c0922 commit 23fd4cc
Show file tree
Hide file tree
Showing 9 changed files with 102 additions and 104 deletions.
19 changes: 13 additions & 6 deletions pv_site_api/_db_helpers.py
Original file line number Diff line number Diff line change
Expand Up @@ -13,6 +13,7 @@

import sqlalchemy as sa
import structlog
from fastapi import Depends
from pvsite_datamodel.read.generation import get_pv_generation_by_sites
from pvsite_datamodel.sqlmodels import ForecastSQL, ForecastValueSQL, InverterSQL, SiteSQL
from sqlalchemy.orm import Session, aliased
Expand All @@ -24,6 +25,7 @@
PVSiteMetadata,
SiteForecastValues,
)
from .session import get_session

logger = structlog.stdlib.get_logger()

Expand Down Expand Up @@ -60,12 +62,6 @@ def _get_forecasts_for_horizon(
return list(session.execute(stmt))


def _get_inverters_by_site(session: Session, site_uuid: str) -> list[Row]:
query = session.query(InverterSQL).filter(InverterSQL.site_uuid == site_uuid)

return query.all()


def _get_latest_forecast_by_sites(
session: Session, site_uuids: list[str], start_utc: Optional[dt.datetime] = None
) -> list[Row]:
Expand Down Expand Up @@ -240,3 +236,14 @@ def does_site_exist(session: Session, site_uuid: str) -> bool:
session.execute(sa.select(SiteSQL).where(SiteSQL.site_uuid == site_uuid)).one_or_none()
is not None
)


def get_inverters_for_site(
site_uuid: str, session: Session = Depends(get_session)
) -> list[Row] | None:
"""Path dependency to get a list of inverters for a site, or None if the site doesn't exist"""
if not does_site_exist(session, site_uuid):
return None

query = session.query(InverterSQL).filter(InverterSQL.site_uuid == site_uuid)
return query.all()
21 changes: 18 additions & 3 deletions pv_site_api/auth.py
Original file line number Diff line number Diff line change
@@ -1,6 +1,10 @@
import jwt
from fastapi import Depends, HTTPException
from fastapi.security import HTTPAuthorizationCredentials, HTTPBearer
from pvsite_datamodel.sqlmodels import ClientSQL
from sqlalchemy.orm import Session

from .session import get_session

token_auth_scheme = HTTPBearer()

Expand All @@ -15,7 +19,11 @@ def __init__(self, domain: str, api_audience: str, algorithm: str):

self._jwks_client = jwt.PyJWKClient(f"https://{domain}/.well-known/jwks.json")

def __call__(self, auth_credentials: HTTPAuthorizationCredentials = Depends(token_auth_scheme)):
def __call__(
self,
auth_credentials: HTTPAuthorizationCredentials = Depends(token_auth_scheme),
session: Session = Depends(get_session),
):
token = auth_credentials.credentials

try:
Expand All @@ -24,7 +32,7 @@ def __call__(self, auth_credentials: HTTPAuthorizationCredentials = Depends(toke
raise HTTPException(status_code=401, detail=str(e))

try:
payload = jwt.decode(
jwt.decode(
token,
signing_key,
algorithms=self._algorithm,
Expand All @@ -34,4 +42,11 @@ def __call__(self, auth_credentials: HTTPAuthorizationCredentials = Depends(toke
except Exception as e:
raise HTTPException(status_code=401, detail=str(e))

return payload
if session is None:
return None

# @TODO: get client corresponding to auth
# See: https://github.com/openclimatefix/pv-site-api/issues/90
client = session.query(ClientSQL).first()
assert client is not None
return client
2 changes: 1 addition & 1 deletion pv_site_api/cache.py
Original file line number Diff line number Diff line change
Expand Up @@ -39,7 +39,7 @@ def wrapper(*args, **kwargs): # noqa
route_variables = kwargs.copy()

# drop session and user
for var in ["session", "user"]:
for var in ["session", "user", "auth"]:
if var in route_variables:
route_variables.pop(var)

Expand Down
22 changes: 18 additions & 4 deletions pv_site_api/enode_auth.py
Original file line number Diff line number Diff line change
Expand Up @@ -12,20 +12,35 @@ def __init__(
self._token_url = token_url
self._access_token = access_token

def auth_flow(self, request: httpx.Request):
def sync_auth_flow(self, request: httpx.Request):
# Add the Authorization header to the request using the current access token
request.headers["Authorization"] = f"Bearer {self._access_token}"
response = yield request

if response.status_code == 401:
# The access token is no longer valid, refresh it
token_response = yield self._build_refresh_request()
token_response.read()
self._update_access_token(token_response)
# Update the request's Authorization header with the new access token
request.headers["Authorization"] = f"Bearer {self._access_token}"
# Resend the request with the new access token
response = yield request
return response
yield request

async def async_auth_flow(self, request: httpx.Request):
# Add the Authorization header to the request using the current access token
request.headers["Authorization"] = f"Bearer {self._access_token}"
response = yield request

if response.status_code == 401:
# The access token is no longer valid, refresh it
token_response = yield self._build_refresh_request()
await token_response.aread()
self._update_access_token(token_response)
# Update the request's Authorization header with the new access token
request.headers["Authorization"] = f"Bearer {self._access_token}"
# Resend the request with the new access token
yield request

def _build_refresh_request(self):
basic_auth = httpx.BasicAuth(self._client_id, self._client_secret)
Expand All @@ -35,5 +50,4 @@ def _build_refresh_request(self):
return request

def _update_access_token(self, response):
response.read()
self._access_token = response.json()["access_token"]
Loading

0 comments on commit 23fd4cc

Please sign in to comment.