Skip to content

Commit

Permalink
Merge pull request #8 from ecmwf-projects/retrieve_session_fix
Browse files Browse the repository at this point in the history
Close the session before downloading the data
  • Loading branch information
aperezpredictia authored Feb 26, 2024
2 parents 8ef907f + c915367 commit 6c16f59
Show file tree
Hide file tree
Showing 9 changed files with 40 additions and 43 deletions.
6 changes: 3 additions & 3 deletions cdsobs/cli/_catalogue_explorer.py
Original file line number Diff line number Diff line change
Expand Up @@ -94,10 +94,10 @@ def list_catalogue(
# with pagination (50 per page)
results = list_catalogue_(session, filters, page)

if len(results) == 0:
raise RuntimeError("No catalogue entries found for these parameters.")
if len(results) == 0:
raise RuntimeError("No catalogue entries found for these parameters.")

print_db_results(results, print_format)
print_db_results(results, print_format)


def list_catalogue_(
Expand Down
16 changes: 7 additions & 9 deletions cdsobs/cli/_retrieve.py
Original file line number Diff line number Diff line change
Expand Up @@ -7,7 +7,6 @@

from cdsobs.cli._utils import CliException, ConfigNotFound, config_yml_typer
from cdsobs.config import validate_config
from cdsobs.observation_catalogue.database import get_session
from cdsobs.retrieve.api import retrieve_observations
from cdsobs.retrieve.models import RetrieveArgs
from cdsobs.storage import S3Client
Expand Down Expand Up @@ -61,13 +60,12 @@ def retrieve(
raise ConfigNotFound()
config = validate_config(cdsobs_config_yml)
s3_client = S3Client.from_config(config.s3config)
with get_session(config.catalogue_db) as session:
output_file = retrieve_observations(
session,
s3_client.public_url_base,
retrieve_args,
output_dir,
size_limit,
)
output_file = retrieve_observations(
config.catalogue_db.get_url(),
s3_client.public_url_base,
retrieve_args,
output_dir,
size_limit,
)
console = Console()
console.print(f"[green] Successfully downloaded {output_file} [/green]")
6 changes: 1 addition & 5 deletions cdsobs/forms_jsons.py
Original file line number Diff line number Diff line change
Expand Up @@ -78,11 +78,7 @@ def get_variables_json(dataset: str, output_path: Path) -> Path:


def get_constraints_json(session, output_path: Path, dataset) -> Path:
"""
JSON file with the constraints in compressed form.
Beware: this is in need of some optimization (it may be resource heavy).
"""
"""JSON file with the constraints in compressed form."""
# This is probably slow, can it be improved?
catalogue_entries = get_catalogue_entries_stream(session, dataset)
merged_constraints = merged_constraints_table(catalogue_entries)
Expand Down
21 changes: 11 additions & 10 deletions cdsobs/retrieve/api.py
Original file line number Diff line number Diff line change
Expand Up @@ -13,7 +13,6 @@
import pandas
import xarray
from fsspec.implementations.http import HTTPFileSystem
from sqlalchemy.orm import Session

from cdsobs.cdm.lite import cdm_lite_variables
from cdsobs.constants import TIME_UNITS_REFERENCE_DATE
Expand All @@ -27,14 +26,14 @@
from cdsobs.retrieve.retrieve_services import estimate_data_size, ezclump
from cdsobs.service_definition.api import get_service_definition
from cdsobs.utils.logutils import SizeError, get_logger
from cdsobs.utils.utils import get_code_mapping
from cdsobs.utils.utils import get_code_mapping, get_database_session

logger = get_logger(__name__)
MAX_NUMBER_OF_GROUPS = 10


def retrieve_observations(
session: Session,
catalogue_url: str,
storage_url: str,
retrieve_args: RetrieveArgs,
output_dir: Path,
Expand All @@ -45,8 +44,9 @@ def retrieve_observations(
Parameters
----------
session:
Session in the catalogue database
catalogue_url:
URL of the catalogue database including credentials, in the form of
"postgresql+psycopg2://someuser:somepass@hostname:port/catalogue"
storage_url:
Storage URL
retrieve_args :
Expand All @@ -58,11 +58,12 @@ def retrieve_observations(
"""
logger.info("Starting retrieve pipeline.")
# Query the storage to get the URLS of the files that contain the data requested
catalogue_repository = CatalogueRepository(session)
entries = _get_catalogue_entries(catalogue_repository, retrieve_args)
object_urls = _get_urls_and_check_size(
entries, retrieve_args, size_limit, storage_url
)
with get_database_session(catalogue_url) as session:
catalogue_repository = CatalogueRepository(session)
entries = _get_catalogue_entries(catalogue_repository, retrieve_args)
object_urls = _get_urls_and_check_size(
entries, retrieve_args, size_limit, storage_url
)
# Get the path of the output file
output_path_netcdf = _get_output_path(output_dir, retrieve_args.dataset, "netCDF")
# First we always write the netCDF-lite file
Expand Down
2 changes: 1 addition & 1 deletion cdsobs/sanity_checks.py
Original file line number Diff line number Diff line change
Expand Up @@ -71,7 +71,7 @@ def _sanity_check_dataset(
check_if_missing_in_object_storage(catalogue_repo, s3_client, dataset_name)
# Retrieve and check output
output_path = retrieve_observations(
session,
config.catalogue_db.get_url(),
s3_client.public_url_base,
retrieve_args,
Path(tmpdir),
Expand Down
4 changes: 1 addition & 3 deletions tests/cli/test_app.py
Original file line number Diff line number Diff line change
Expand Up @@ -31,7 +31,7 @@ def test_cli_make_production(verbose):
assert result.exit_code == 0


@pytest.mark.skip(reason="this test does not reset db after running")
# @pytest.mark.skip(reason="this test does not reset db after running")
def test_cli_retrieve(tmp_path, test_repository):
runner = CliRunner()
test_json_str = """[
Expand Down Expand Up @@ -61,8 +61,6 @@ def test_cli_retrieve(tmp_path, test_repository):
CONFIG_YML,
"--output-dir",
str(tmp_path),
"--np",
"2",
]
result = runner.invoke(
app,
Expand Down
11 changes: 7 additions & 4 deletions tests/cli/test_catalogue_explorer.py
Original file line number Diff line number Diff line change
@@ -1,3 +1,4 @@
import pytest
from typer.testing import CliRunner

from cdsobs.cli._catalogue_explorer import list_catalogue_
Expand All @@ -8,10 +9,11 @@
runner = CliRunner()


def test_list_catalogue(test_session, test_repository):
@pytest.mark.parametrize("print_format", ["table", "json"])
def test_list_catalogue(test_session, test_repository, print_format):
result = runner.invoke(
app,
["list-catalogue", "-c", CONFIG_YML],
["list-catalogue", "-c", CONFIG_YML, "--print-format", print_format],
catch_exceptions=False,
)
assert result.exit_code == 0
Expand All @@ -26,10 +28,11 @@ def test_catalogue_dataset_info(test_session, test_repository):
assert result.exit_code == 0


def test_list_datasets():
@pytest.mark.parametrize("print_format", ["table", "json"])
def test_list_datasets(print_format):
result = runner.invoke(
app,
["list-datasets", "-c", CONFIG_YML, "--print-format", "json"],
["list-datasets", "-c", CONFIG_YML, "--print-format", print_format],
catch_exceptions=False,
)
assert result.exit_code == 0
Expand Down
10 changes: 5 additions & 5 deletions tests/retrieve/test_api.py
Original file line number Diff line number Diff line change
Expand Up @@ -7,7 +7,6 @@

from cdsobs.config import CDSObsConfig
from cdsobs.constants import CONFIG_YML
from cdsobs.observation_catalogue.database import get_session
from cdsobs.retrieve.api import retrieve_observations
from cdsobs.retrieve.models import RetrieveArgs
from cdsobs.storage import S3Client
Expand All @@ -23,7 +22,9 @@


@pytest.mark.parametrize("oformat,dataset_source,time_coverage", PARAMETRIZE_VALUES)
def test_retrieve(test_repository, tmp_path, oformat, dataset_source, time_coverage):
def test_retrieve(
test_repository, test_config, tmp_path, oformat, dataset_source, time_coverage
):
dataset_name = "insitu-observations-woudc-ozone-total-column-and-profiles"
start_year, end_year = get_test_years(dataset_source)
if dataset_source == "OzoneSonde":
Expand Down Expand Up @@ -52,7 +53,7 @@ def test_retrieve(test_repository, tmp_path, oformat, dataset_source, time_cover
retrieve_args = RetrieveArgs(dataset=dataset_name, params=params)
start = datetime.now()
output_file = retrieve_observations(
test_repository.catalogue_repository.session,
test_config.catalogue_db.get_url(),
test_repository.s3_client.base,
retrieve_args,
tmp_path,
Expand Down Expand Up @@ -86,10 +87,9 @@ def test_retrieve_cuon():
],
}
retrieve_args = RetrieveArgs(dataset=dataset_name, params=params)
session = get_session(test_config.catalogue_db)
s3_client = S3Client.from_config(test_config.s3config)
output_file = retrieve_observations(
session,
test_config.catalogue_db.get_url(),
s3_client.base,
retrieve_args,
Path("/tmp"),
Expand Down
7 changes: 4 additions & 3 deletions tests/system/1_year_benchmarks.py
Original file line number Diff line number Diff line change
Expand Up @@ -95,8 +95,9 @@ def main():
retrieve_args = RetrieveArgs(dataset=dataset_name, params=params)
s3_client = S3Client.from_config(config.s3config)
start_time = time.perf_counter()
catalogue_url = config.catalogue_db.get_url()
retrieve_funct(
session,
catalogue_url,
s3_client.public_url_base,
retrieve_args,
tmpdir,
Expand Down Expand Up @@ -127,7 +128,7 @@ def main():
retrieve_args = RetrieveArgs(dataset=dataset_name, params=params)
start_time = time.perf_counter()
retrieve_funct(
session,
catalogue_url,
s3_client.public_url_base,
retrieve_args,
tmpdir,
Expand All @@ -151,7 +152,7 @@ def main():
retrieve_args = RetrieveArgs(dataset=dataset_name, params=params)
start_time = time.perf_counter()
retrieve_funct(
session,
catalogue_url,
s3_client.public_url_base,
retrieve_args,
tmpdir,
Expand Down

0 comments on commit 6c16f59

Please sign in to comment.