diff --git a/pyproject.toml b/pyproject.toml index 9143976..c321539 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -19,7 +19,7 @@ classifiers = ["Programming Language :: Python :: 3"] dependencies = [ "cryptography >= 42.0.7", "fastapi >= 0.105.0", - "pvsite-datamodel >= 1.0.41", + "pvsite-datamodel >= 1.0.45", "pyjwt >= 2.8.0", "pyproj >= 3.3.0", "pytz >= 2023.3", diff --git a/src/india_api/internal/inputs/indiadb/client.py b/src/india_api/internal/inputs/indiadb/client.py index 2b09b2a..966f528 100644 --- a/src/india_api/internal/inputs/indiadb/client.py +++ b/src/india_api/internal/inputs/indiadb/client.py @@ -13,6 +13,7 @@ get_pv_generation_by_sites, get_user_by_email, get_sites_from_user, + get_site_by_uuid, ) from pvsite_datamodel.write.generation import insert_generation_values from pvsite_datamodel.sqlmodels import SiteAssetType, ForecastValueSQL @@ -103,6 +104,10 @@ def get_predicted_power_production_for_location( site = sites[0] + if site.ml_model is not None: + ml_model_name = site.ml_model.model_name + log.info(f"Using ml model {site.ml_model.model_name}") + # read actual generations values = get_latest_forecast_values_by_site( session, @@ -286,6 +291,12 @@ def get_site_forecast(self, site_uuid: str, email: str) -> list[internal.Predict with self._get_session() as session: check_user_has_access_to_site(session=session, email=email, site_uuid=site_uuid) + # get site and the get the ml model name + site = get_site_by_uuid(session=session, site_uuid=site_uuid) + if site.ml_model is not None: + ml_model_name = site.ml_model.model_name + log.info(f"Using ml model {site.ml_model.model_name}") + if isinstance(site_uuid, str): site_uuid = UUID(site_uuid)