From 5d706c7a3808b1c21868bed53d551ffbac7d5970 Mon Sep 17 00:00:00 2001 From: peterdudfield Date: Wed, 13 Nov 2024 11:36:06 +0000 Subject: [PATCH] fix tests --- .../internal/inputs/indiadb/client.py | 4 ++-- .../internal/inputs/indiadb/conftest.py | 22 +++++++++++++++++++ .../internal/inputs/indiadb/test_indiadb.py | 6 ++--- 3 files changed, 27 insertions(+), 5 deletions(-) diff --git a/src/india_api/internal/inputs/indiadb/client.py b/src/india_api/internal/inputs/indiadb/client.py index 36a77d3..bea2bf2 100644 --- a/src/india_api/internal/inputs/indiadb/client.py +++ b/src/india_api/internal/inputs/indiadb/client.py @@ -195,7 +195,7 @@ def get_predicted_solar_power_production_for_location( forecast_horizon=forecast_horizon, forecast_horizon_minutes=forecast_horizon_minutes, smooth_flag=smooth_flag, - model_name=model_name, + ml_model_name=model_name, ) def get_predicted_wind_power_production_for_location( @@ -224,7 +224,7 @@ def get_predicted_wind_power_production_for_location( forecast_horizon=forecast_horizon, forecast_horizon_minutes=forecast_horizon_minutes, smooth_flag=smooth_flag, - model_name=model_name, + ml_model_name=model_name, ) def get_actual_solar_power_production_for_location( diff --git a/src/india_api/internal/inputs/indiadb/conftest.py b/src/india_api/internal/inputs/indiadb/conftest.py index cf294d8..aa22369 100644 --- a/src/india_api/internal/inputs/indiadb/conftest.py +++ b/src/india_api/internal/inputs/indiadb/conftest.py @@ -6,6 +6,7 @@ import pytest from pvsite_datamodel.sqlmodels import Base, ForecastSQL, ForecastValueSQL, GenerationSQL, SiteSQL from pvsite_datamodel.read.user import get_user_by_email +from pvsite_datamodel.read.model import get_or_create_model from sqlalchemy import create_engine from sqlalchemy.orm import Session from testcontainers.postgres import PostgresContainer @@ -123,6 +124,23 @@ def generations(db_session, sites): @pytest.fixture() def forecast_values(db_session, sites): """Create some fake forecast values""" + + make_fake_forecast_values(db_session, sites, "pvnet_india") + +@pytest.fixture() +def forecast_values_wind(db_session, sites): + """Create some fake forecast values""" + + make_fake_forecast_values(db_session, sites, "windnet_india") + +@pytest.fixture() +def forecast_values_site(db_session, sites): + """Create some fake forecast values""" + + make_fake_forecast_values(db_session, sites, "pvnet_ad_sites") + + +def make_fake_forecast_values(db_session, sites, model_name): forecast_values = [] forecast_version: str = "0.0.0" @@ -134,6 +152,9 @@ def forecast_values(db_session, sites): # To make things trickier we make a second forecast at the same for one of the timestamps. timestamps = timestamps + timestamps[-1:] + # get model + ml_model = get_or_create_model(db_session, model_name) + for site in sites: for timestamp in timestamps: forecast: ForecastSQL = ForecastSQL( @@ -154,6 +175,7 @@ def forecast_values(db_session, sites): end_utc=timestamp + timedelta(minutes=horizon + duration), horizon_minutes=horizon, ) + forecast_value.ml_model = ml_model forecast_values.append(forecast_value) diff --git a/src/india_api/internal/inputs/indiadb/test_indiadb.py b/src/india_api/internal/inputs/indiadb/test_indiadb.py index c9b0791..7c8e28b 100644 --- a/src/india_api/internal/inputs/indiadb/test_indiadb.py +++ b/src/india_api/internal/inputs/indiadb/test_indiadb.py @@ -23,7 +23,7 @@ def client(engine, db_session): class TestIndiaDBClient: def test_get_predicted_wind_power_production_for_location( - self, client, forecast_values + self, client, forecast_values_wind ) -> None: locID = "testID" result = client.get_predicted_wind_power_production_for_location(locID) @@ -33,7 +33,7 @@ def test_get_predicted_wind_power_production_for_location( assert isinstance(record, PredictedPower) def test_get_predicted_wind_power_production_for_location_raise_error( - self, client, forecast_values + self, client, forecast_values_wind ) -> None: with pytest.raises(Exception): @@ -83,7 +83,7 @@ def test_get_sites_no_sites(self, client, sites) -> None: sites_from_api = client.get_sites(email="test2@test.com") assert len(sites_from_api) == 0 - def test_get_site_forecast(self, client, sites, forecast_values) -> None: + def test_get_site_forecast(self, client, sites, forecast_values_site) -> None: out = client.get_site_forecast(site_uuid=str(sites[0].site_uuid), email="test@test.com") assert len(out) > 0