diff --git a/python/lsst/rubintv/analysis/service/commands/__init__.py b/python/lsst/rubintv/analysis/service/commands/__init__.py index 268e95b..43e4f70 100644 --- a/python/lsst/rubintv/analysis/service/commands/__init__.py +++ b/python/lsst/rubintv/analysis/service/commands/__init__.py @@ -21,4 +21,5 @@ from .butler import * from .db import * +from .file import * from .image import * diff --git a/python/lsst/rubintv/analysis/service/commands/file.py b/python/lsst/rubintv/analysis/service/commands/file.py new file mode 100644 index 0000000..923134a --- /dev/null +++ b/python/lsst/rubintv/analysis/service/commands/file.py @@ -0,0 +1,441 @@ +# This file is part of lsst_rubintv_analysis_service. +# +# Developed for the LSST Data Management System. +# This product includes software developed by the LSST Project +# (https://www.lsst.org). +# See the COPYRIGHT file at the top-level directory of this distribution +# for details of code ownership. +# +# This program is free software: you can redistribute it and/or modify +# it under the terms of the GNU General Public License as published by +# the Free Software Foundation, either version 3 of the License, or +# (at your option) any later version. +# +# This program is distributed in the hope that it will be useful, +# but WITHOUT ANY WARRANTY; without even the implied warranty of +# MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the +# GNU General Public License for more details. +# +# You should have received a copy of the GNU General Public License +# along with this program. If not, see . + +from __future__ import annotations + +import logging +import os +import shutil +from dataclasses import dataclass +from typing import TYPE_CHECKING + +from ..command import BaseCommand + +if TYPE_CHECKING: + from ..data import DataCenter + +MAX_FILE_SIZE = 10 * 1024 * 1024 # 10 MB + + +def sanitize_path(base_path: str, user_path: list[str]) -> str: + """Sanitize and validate a user-provided path. + + Parameters + ---------- + base_path + The root directory that shouldn't be escaped. + user_path + List of path components provided by the user. + + Returns + ------- + result: + A sanitized absolute path, or None if the path is invalid. + """ + # Join the path components and normalize + full_path = os.path.normpath(os.path.join(base_path, *user_path)) + + # Check if the resulting path is within the base_path + if not full_path.startswith(os.path.abspath(base_path)): + raise ValueError(f"Invalid path: {full_path}") + + return full_path + + +class FileOperationError(Exception): + """Custom exception for file operations.""" + + pass + + +@dataclass(kw_only=True) +class LoadDirectoryCommand(BaseCommand): + """Load the files and sub directories contained in a directory. + + Attributes + ---------- + path + The path to the directory to list. + response_type + The type of response to send back to the client. + """ + + path: list[str] + response_type: str = "directory files" + + def build_contents(self, data_center: DataCenter) -> dict: + try: + full_path = sanitize_path(data_center.user_path, self.path) + if full_path is None: + raise FileOperationError("Invalid path") + + if not os.path.exists(full_path): + raise FileOperationError(f"The path '{full_path}' does not exist.") + + if not os.path.isdir(full_path): + raise FileOperationError(f"The path '{full_path}' is not a directory.") + + all_items = os.listdir(full_path) + files = [f for f in all_items if os.path.isfile(os.path.join(full_path, f))] + directories = [d for d in all_items if os.path.isdir(os.path.join(full_path, d))] + + logging.info(f"Directory contents listed: {full_path}") + return { + "path": self.path, + "files": sorted(files), + "directories": sorted(directories), + } + except FileOperationError as e: + logging.error(f"File operation error: {str(e)}") + return {"error": str(e)} + except Exception as e: + logging.error(f"Unexpected error: {str(e)}") + return {"error": f"An unexpected error occurred: {str(e)}"} + + +@dataclass(kw_only=True) +class CreateDirectoryCommand(BaseCommand): + """Create a new directory. + + Attributes + ---------- + path + The path to the parent directory. + name + The name of the new directory to create. + """ + + path: list[str] + name: str + response_type: str = "directory created" + + def build_contents(self, data_center: DataCenter) -> dict: + try: + parent_path = sanitize_path(data_center.user_path, self.path) + if parent_path is None: + raise FileOperationError("Invalid path") + + full_path = os.path.join(parent_path, self.name) + if not full_path.startswith(data_center.user_path): + raise FileOperationError("Invalid directory name") + + os.makedirs(full_path, exist_ok=True) + logging.info(f"Directory created: {full_path}") + return {"path": full_path, "parent_path": self.path, "name": self.name} + except FileOperationError as e: + logging.error(f"File operation error: {str(e)}") + return {"error": str(e)} + except Exception as e: + logging.error(f"Unexpected error: {str(e)}") + return {"error": f"An unexpected error occurred: {str(e)}"} + + +@dataclass(kw_only=True) +class RenameFileCommand(BaseCommand): + """Rename a file or directory. + + Attributes + ---------- + path + The path to the file or directory to rename. + new_name + The new name to assign to the file or directory. + """ + + path: list[str] + new_name: str + response_type: str = "file renamed" + + def build_contents(self, data_center: DataCenter) -> dict: + try: + full_path = sanitize_path(data_center.user_path, self.path) + if full_path is None: + raise FileOperationError("Invalid path") + + new_path = os.path.join(os.path.dirname(full_path), self.new_name) + if not new_path.startswith(data_center.user_path): + raise FileOperationError("Invalid new name") + + if not os.path.exists(full_path): + raise FileOperationError(f"The source path '{full_path}' does not exist.") + + if os.path.exists(new_path): + raise FileOperationError(f"The new path '{new_path}' already exists. Cannot overwrite.") + + os.rename(full_path, new_path) + logging.info(f"File renamed: {full_path} to {new_path}") + return {"new_path": new_path, "new_name": self.new_name, "path": self.path} + except FileOperationError as e: + logging.error(f"File operation error: {str(e)}") + return {"error": str(e)} + except Exception as e: + logging.error(f"Unexpected error: {str(e)}") + return {"error": f"An unexpected error occurred: {str(e)}"} + + +@dataclass(kw_only=True) +class DeleteFileCommand(BaseCommand): + """Delete a file or directory. + + Attributes + ---------- + path + The path to the file or directory to delete. + """ + + path: list[str] + response_type: str = "file deleted" + + def build_contents(self, data_center: DataCenter) -> dict: + try: + full_path = sanitize_path(data_center.user_path, self.path) + if full_path is None: + raise FileOperationError("Invalid path") + + if not os.path.exists(full_path): + raise FileOperationError(f"The path '{full_path}' does not exist.") + + if os.path.isfile(full_path): + os.remove(full_path) + logging.info(f"File deleted: {full_path}") + return {"deleted_path": self.path, "type": "file"} + elif os.path.isdir(full_path): + shutil.rmtree(full_path) + logging.info(f"Directory deleted: {full_path}") + return {"deleted_path": self.path, "type": "directory"} + else: + raise FileOperationError(f"The path '{full_path}' is neither a file nor a directory.") + except FileOperationError as e: + logging.error(f"File operation error: {str(e)}") + return {"error": str(e)} + except Exception as e: + logging.error(f"Unexpected error: {str(e)}") + return {"error": f"An unexpected error occurred: {str(e)}"} + + +@dataclass(kw_only=True) +class DuplicateFileCommand(BaseCommand): + """Duplicate a file or directory. + + Attributes + ---------- + path + The path to the file or directory to duplicate + """ + + path: list[str] + response_type: str = "file duplicated" + + def build_contents(self, data_center: DataCenter) -> dict: + try: + full_path = sanitize_path(data_center.user_path, self.path) + if full_path is None: + raise FileOperationError("Invalid path") + + if not os.path.exists(full_path): + raise FileOperationError(f"The path '{full_path}' does not exist.") + + dir_path = os.path.dirname(full_path) + base_name = os.path.basename(full_path) + new_path = os.path.join(dir_path, f"{base_name}_copy") + counter = 1 + + while os.path.exists(new_path): + new_path = os.path.join(dir_path, f"{base_name}_copy_{counter}") + counter += 1 + + new_filename = os.path.basename(new_path) + + if os.path.isfile(full_path): + shutil.copy2(full_path, new_path) + logging.info(f"File duplicated: {full_path} to {new_path}") + return { + "path": self.path[:-1], + "old_name": self.path[-1], + "new_filename": new_filename, + "type": "file", + } + elif os.path.isdir(full_path): + shutil.copytree(full_path, new_path) + logging.info(f"Directory duplicated: {full_path} to {new_path}") + return {"new_path": new_path, "type": "directory"} + else: + raise FileOperationError(f"The path '{full_path}' is neither a file nor a directory.") + except FileOperationError as e: + logging.error(f"File operation error: {str(e)}") + return {"error": str(e)} + except Exception as e: + logging.error(f"Unexpected error: {str(e)}") + return {"error": f"An unexpected error occurred: {str(e)}"} + + +@dataclass(kw_only=True) +class MoveFileCommand(BaseCommand): + """Move a file or directory. + + Attributes + ---------- + source_path + The path to the file or directory to move. + destination_path + The path to move the file or directory to. + """ + + source_path: list[str] + destination_path: list[str] + response_type: str = "file moved" + + def build_contents(self, data_center: DataCenter) -> dict: + try: + full_source_path = sanitize_path(data_center.user_path, self.source_path) + destination_path = sanitize_path(data_center.user_path, self.destination_path) + full_destination_path = os.path.join(destination_path, os.path.basename(full_source_path)) + + if full_source_path is None or full_destination_path is None: + raise FileOperationError("Invalid path") + + if not os.path.exists(full_source_path): + raise FileOperationError(f"The source path '{full_source_path}' does not exist.") + + if os.path.exists(full_destination_path): + raise FileOperationError( + f"The destination path '{full_destination_path}' already exists. " + "Use overwrite=True to overwrite." + ) + + os.makedirs(os.path.dirname(full_destination_path), exist_ok=True) + shutil.move(full_source_path, full_destination_path) + + logging.info(f"File moved: {full_source_path} to {full_destination_path}") + return { + "destination_path": self.destination_path, + "source_path": self.source_path, + "type": "file" if os.path.isfile(full_destination_path) else "directory", + } + except FileOperationError as e: + logging.error(f"File operation error: {str(e)}") + return {"error": str(e)} + except Exception as e: + logging.error(f"Unexpected error: {str(e)}") + return {"error": f"An unexpected error occurred: {str(e)}"} + + +@dataclass(kw_only=True) +class SaveFileCommand(BaseCommand): + """Save a file with the provided content. + + Attributes + ---------- + path + The path to the file to save. + content + The content to write to the file. + """ + + path: list[str] + content: str + response_type: str = "file saved" + + def build_contents(self, data_center: DataCenter) -> dict: + try: + full_path = sanitize_path(data_center.user_path, self.path) + if full_path is None: + raise FileOperationError("Invalid path") + + if os.path.exists(full_path) and os.path.isdir(full_path): + raise FileOperationError(f"The path '{full_path}' already exists as a directory.") + + with open(full_path, "w", encoding="utf-8") as f: + f.write(self.content) + + logging.info(f"File saved: {full_path}") + return {"saved_path": full_path} + except FileOperationError as e: + logging.error(f"File operation error: {str(e)}") + return {"error": str(e)} + except Exception as e: + logging.error(f"Unexpected error: {str(e)}") + return {"error": f"An unexpected error occurred: {str(e)}"} + + +@dataclass(kw_only=True) +class LoadFileCommand(BaseCommand): + """Load the contents of a file. + + Attributes + ---------- + path + The path to the file to load. + max_size + The maximum allowed size of the file in bytes. + This prevents accendentally loading large files. + """ + + path: list[str] + max_size: int = MAX_FILE_SIZE + response_type: str = "file content" + + def build_contents(self, data_center: DataCenter) -> dict: + try: + full_path = sanitize_path(data_center.user_path, self.path) + if full_path is None: + raise FileOperationError("Invalid path") + + if not os.path.exists(full_path): + raise FileOperationError(f"The file '{full_path}' does not exist.") + + if not os.path.isfile(full_path): + raise FileOperationError(f"The path '{full_path}' is not a file.") + + file_size = os.path.getsize(full_path) + if file_size > self.max_size: + raise FileOperationError( + f"The file '{full_path}' exceeds the maximum allowed size of {self.max_size} bytes." + ) + + with open(full_path, "r", encoding="utf-8") as f: + content = f.read() + + logging.info(f"File loaded: {full_path}") + return {"content": content, "path": full_path, "size": file_size, "encoding": "utf-8"} + except FileOperationError as e: + logging.error(f"File operation error: {str(e)}") + return {"error": str(e)} + except UnicodeDecodeError: + logging.error(f"Unicode decode error: {full_path}") + return { + "error": f"Unable to decode '{full_path}' as UTF-8. " + "The file might be binary or use a different encoding." + } + except Exception as e: + logging.error(f"Unexpected error: {str(e)}") + return {"error": f"An unexpected error occurred: {str(e)}"} + + +# Register the commands +LoadDirectoryCommand.register("list directory") +CreateDirectoryCommand.register("create directory") +RenameFileCommand.register("rename") +DeleteFileCommand.register("delete") +DuplicateFileCommand.register("duplicate") +MoveFileCommand.register("move") +SaveFileCommand.register("save") +LoadFileCommand.register("load") diff --git a/python/lsst/rubintv/analysis/service/data.py b/python/lsst/rubintv/analysis/service/data.py index 196d569..e85394c 100644 --- a/python/lsst/rubintv/analysis/service/data.py +++ b/python/lsst/rubintv/analysis/service/data.py @@ -143,16 +143,19 @@ class DataCenter: An EFD client instance. """ + user_path: str schemas: dict[str, ConsDbSchema] butlers: dict[str, Butler] | None = None efd_client: EfdClient | None = None def __init__( self, + user_path: str, schemas: dict[str, ConsDbSchema], butlers: dict[str, Butler] | None = None, efd_client: EfdClient | None = None, ): + self.user_path = user_path self.schemas = schemas self.butlers = butlers self.efdClient = efd_client diff --git a/python/lsst/rubintv/analysis/service/query.py b/python/lsst/rubintv/analysis/service/query.py index 08a6b58..912cad7 100644 --- a/python/lsst/rubintv/analysis/service/query.py +++ b/python/lsst/rubintv/analysis/service/query.py @@ -38,6 +38,19 @@ class QueryError(Exception): pass +left_operator = { + "eq": "eq", + "ne": "ne", + "lt": "gt", + "lte": "gte", + "gt": "lt", + "gte": "lte", + "starts with": "starts with", + "ends with": "ends with", + "contains": "contains", +} + + @dataclass class QueryResult: """The result of a query. @@ -87,10 +100,49 @@ def from_dict(query_dict: dict[str, Any]) -> Query: to initialize the query. """ try: - if query_dict["name"] == "EqualityQuery": - return EqualityQuery.from_dict(query_dict["content"]) - elif query_dict["name"] == "ParentQuery": - return ParentQuery.from_dict(query_dict["content"]) + if query_dict["type"] == "EqualityQuery": + if "leftOperator" not in query_dict: + return EqualityQuery.from_dict( + { + "column": query_dict["field"], + "operator": query_dict["rightOperator"], + "value": query_dict["rightValue"], + } + ) + if "rightOperator" not in query_dict: + return EqualityQuery.from_dict( + { + "column": query_dict["field"], + "operator": left_operator[query_dict["leftOperator"]], + "value": query_dict["leftValue"], + } + ) + return ParentQuery.from_dict( + { + "children": [ + { + "type": "EqualityQuery", + "field": query_dict["field"], + "leftOperator": query_dict["leftOperator"], + "leftValue": query_dict["leftValue"], + }, + { + "type": "EqualityQuery", + "field": query_dict["field"], + "rightOperator": query_dict["rightOperator"], + "rightValue": query_dict["rightValue"], + }, + ], + "operator": "AND", + } + ) + elif query_dict["type"] == "ParentQuery": + return ParentQuery.from_dict( + { + "children": query_dict["children"], + "operator": query_dict["operator"], + } + ) except Exception: raise QueryError(f"Failed to parse query: {query_dict}") @@ -136,7 +188,11 @@ def __call__(self, database: ConsDbSchema) -> QueryResult: @staticmethod def from_dict(query_dict: dict[str, Any]) -> EqualityQuery: - return EqualityQuery(**query_dict) + return EqualityQuery( + column=f'{query_dict["column"]["schema"]}.{query_dict["column"]["name"]}', + operator=query_dict["operator"], + value=query_dict["value"], + ) class ParentQuery(Query): diff --git a/scripts/rubintv_worker.py b/scripts/rubintv_worker.py index 1d63bd7..1352cd0 100644 --- a/scripts/rubintv_worker.py +++ b/scripts/rubintv_worker.py @@ -37,7 +37,11 @@ default_joins = os.path.join(pathlib.Path(__file__).parent.absolute(), "joins.yaml") logger = logging.getLogger("lsst.rubintv.analysis.server.worker") sdm_schemas_path = os.path.join(os.path.expandvars("$SDM_SCHEMAS_DIR"), "yml") -credentials_path = os.path.join(os.path.expanduser("~"), ".lsst", "postgres-credentials.txt") +prod_credentials_path = os.path.join("/etc/secrets", "postgres-credentials.txt") +test_credentials_path = os.path.join(os.path.expanduser("~"), ".lsst", "postgres-credentials.txt") +summit_users_path = "/usr/share/worker/configs" +usdf_users_path = "/usr/share/worker/configs" +dev_users_path = "/sdf/home/f/fred3m/u/data/dev_users" class UniversalToVisit(DataMatch): @@ -62,7 +66,7 @@ def main(): "--location", default="usdf", type=str, - help="Location of the worker (either 'summit' or 'usdf')", + help="Location of the worker (either 'summit', 'usdf', or 'dev')", ) parser.add_argument( "--log", @@ -111,8 +115,16 @@ def main(): server = "" if args.location.lower() == "summit": server = config["locations"]["summit"] + credentials_path = prod_credentials_path + user_path = summit_users_path elif args.location.lower() == "usdf": server = config["locations"]["usdf"] + credentials_path = prod_credentials_path + user_path = usdf_users_path + elif args.location.lower() == "dev": + server = config["locations"]["usdf"] + credentials_path = test_credentials_path + user_path = dev_users_path else: raise ValueError(f"Invalid location: {args.location}, must be either 'summit' or 'usdf'") @@ -153,7 +165,7 @@ def main(): # Create the DataCenter that keeps track of all data sources. # This will have to be updated every time we want to # change/add a new data source. - data_center = DataCenter(schemas=schemas, butlers=butlers, efd_client=efd_client) + data_center = DataCenter(schemas=schemas, butlers=butlers, efd_client=efd_client, user_path=user_path) # Run the client and connect to rubinTV via websockets logger.info("Initializing worker") diff --git a/tests/test_command.py b/tests/test_command.py index 58640eb..467aec3 100644 --- a/tests/test_command.py +++ b/tests/test_command.py @@ -91,12 +91,13 @@ def test_load_columns_with_query(self): "exposure.dec", ], "query": { - "name": "EqualityQuery", - "content": { - "column": "visit1_quicklook.exp_time", - "operator": "eq", - "value": 30, + "type": "EqualityQuery", + "field": { + "schema": "visit1_quicklook", + "name": "exp_time", }, + "rightOperator": "eq", + "rightValue": 30, }, }, } diff --git a/tests/test_query.py b/tests/test_query.py index b97da65..0f77a9b 100644 --- a/tests/test_query.py +++ b/tests/test_query.py @@ -86,21 +86,23 @@ def test_database_query(self): # dec > 0 (and is not None) query1 = { - "name": "EqualityQuery", - "content": { - "column": "exposure.dec", - "operator": "gt", - "value": 0, + "type": "EqualityQuery", + "field": { + "schema": "exposure", + "name": "dec", }, + "leftOperator": "lt", + "leftValue": 0, } # ra > 60 (and is not None) query2 = { - "name": "EqualityQuery", - "content": { - "column": "exposure.ra", - "operator": "gt", - "value": 60, + "type": "EqualityQuery", + "field": { + "schema": "exposure", + "name": "ra", }, + "leftOperator": "lt", + "leftValue": 60, } # Test 1: dec > 0 (and is not None) @@ -117,11 +119,9 @@ def test_database_query(self): # Test 2: dec > 0 and ra > 60 (and neither is None) query = { - "name": "ParentQuery", - "content": { - "operator": "AND", - "children": [query1, query2], - }, + "type": "ParentQuery", + "operator": "AND", + "children": [query1, query2], } result = self.database.query(["exposure.ra", "exposure.dec"], query=lras.query.Query.from_dict(query)) truth = data[[False, False, False, False, False, False, False, False, True, True]] @@ -135,20 +135,16 @@ def test_database_query(self): # Test 3: dec <= 0 or ra > 60 (and neither is None) query = { - "name": "ParentQuery", - "content": { - "operator": "OR", - "children": [ - { - "name": "ParentQuery", - "content": { - "operator": "NOT", - "children": [query1], - }, - }, - query2, - ], - }, + "type": "ParentQuery", + "operator": "OR", + "children": [ + { + "type": "ParentQuery", + "operator": "NOT", + "children": [query1], + }, + query2, + ], } result = self.database.query(["exposure.ra", "exposure.dec"], query=lras.query.Query.from_dict(query)) @@ -163,11 +159,9 @@ def test_database_query(self): # Test 4: dec > 0 XOR ra > 60 query = { - "name": "ParentQuery", - "content": { - "operator": "XOR", - "children": [query1, query2], - }, + "type": "ParentQuery", + "operator": "XOR", + "children": [query1, query2], } result = self.database.query(["exposure.ra", "exposure.dec"], query=lras.query.Query.from_dict(query)) truth = data[[False, False, False, False, False, True, False, False, False, False]] @@ -184,12 +178,13 @@ def test_database_string_query(self): # Test equality query = { - "name": "EqualityQuery", - "content": { - "column": "exposure.physical_filter", - "operator": "eq", - "value": "DECam r-band", + "type": "EqualityQuery", + "field": { + "schema": "exposure", + "name": "physical_filter", }, + "rightOperator": "eq", + "rightValue": "DECam r-band", } result = self.database.query(["exposure.physical_filter"], query=lras.query.Query.from_dict(query)) truth = data[[False, False, False, False, False, False, True, False, False, False]] @@ -202,12 +197,13 @@ def test_database_string_query(self): # Test "startswith" query = { - "name": "EqualityQuery", - "content": { - "column": "exposure.physical_filter", - "operator": "startswith", - "value": "DECam", + "type": "EqualityQuery", + "field": { + "schema": "exposure", + "name": "physical_filter", }, + "rightOperator": "startswith", + "rightValue": "DECam", } result = self.database.query(["exposure.physical_filter"], query=lras.query.Query.from_dict(query)) truth = data[[False, False, False, False, False, True, True, True, True, True]] @@ -220,12 +216,13 @@ def test_database_string_query(self): # Test "endswith" query = { - "name": "EqualityQuery", - "content": { - "column": "exposure.physical_filter", - "operator": "endswith", - "value": "r-band", + "type": "EqualityQuery", + "field": { + "schema": "exposure", + "name": "physical_filter", }, + "rightOperator": "endswith", + "rightValue": "r-band", } result = self.database.query(["exposure.physical_filter"], query=lras.query.Query.from_dict(query)) truth = data[[False, True, False, False, False, False, True, False, False, False]] @@ -238,12 +235,13 @@ def test_database_string_query(self): # Test "like" query = { - "name": "EqualityQuery", - "content": { - "column": "exposure.physical_filter", - "operator": "contains", - "value": "T r", + "type": "EqualityQuery", + "field": { + "schema": "exposure", + "name": "physical_filter", }, + "rightOperator": "contains", + "rightValue": "T r", } result = self.database.query(["exposure.physical_filter"], query=lras.query.Query.from_dict(query)) truth = data[[False, True, False, False, False, False, False, False, False, False]] @@ -259,12 +257,13 @@ def test_database_datatime_query(self): # Test < query1 = { - "name": "EqualityQuery", - "content": { - "column": "exposure.obs_start", - "operator": "lt", - "value": "2023-05-19 23:23:23", + "type": "EqualityQuery", + "field": { + "schema": "exposure", + "name": "obs_start", }, + "rightOperator": "lt", + "rightValue": "2023-05-19 23:23:23", } result = self.database.query(["exposure.obs_start"], query=lras.query.Query.from_dict(query1)) truth = data[[True, True, True, False, False, True, True, True, True, True]] @@ -277,12 +276,13 @@ def test_database_datatime_query(self): # Test > query2 = { - "name": "EqualityQuery", - "content": { - "column": "exposure.obs_start", - "operator": "gt", - "value": "2023-05-01 23:23:23", + "type": "EqualityQuery", + "field": { + "schema": "exposure", + "name": "obs_start", }, + "leftOperator": "lt", + "leftValue": "2023-05-01 23:23:23", } result = self.database.query(["exposure.obs_start"], query=lras.query.Query.from_dict(query2)) truth = data[[True, True, True, True, True, False, False, False, False, False]] @@ -295,11 +295,9 @@ def test_database_datatime_query(self): # Test in range query3 = { - "name": "ParentQuery", - "content": { - "operator": "AND", - "children": [query1, query2], - }, + "type": "ParentQuery", + "operator": "AND", + "children": [query1, query2], } result = self.database.query(["exposure.obs_start"], query=lras.query.Query.from_dict(query3)) truth = data[[True, True, True, False, False, False, False, False, False, False]] @@ -322,29 +320,29 @@ def test_multiple_table_query(self): # dec > 0 (and is not None) query1 = { - "name": "EqualityQuery", - "content": { - "column": "exposure.dec", - "operator": "gt", - "value": 0, + "type": "EqualityQuery", + "field": { + "schema": "exposure", + "name": "dec", }, + "leftOperator": "lt", + "leftValue": 0, } # exposure time == 30 (and is not None) query2 = { - "name": "EqualityQuery", - "content": { - "column": "visit1_quicklook.exp_time", - "operator": "eq", - "value": 30, + "type": "EqualityQuery", + "field": { + "schema": "visit1_quicklook", + "name": "exp_time", }, + "rightOperator": "eq", + "rightValue": 30, } # Intersection of the two queries query3 = { - "name": "ParentQuery", - "content": { - "operator": "AND", - "children": [query1, query2], - }, + "type": "ParentQuery", + "operator": "AND", + "children": [query1, query2], } valid = ( diff --git a/tests/utils.py b/tests/utils.py index 828a6e0..9008b60 100644 --- a/tests/utils.py +++ b/tests/utils.py @@ -218,7 +218,7 @@ def setUp(self): # Create the datacenter self.database = ConsDbSchema(schema=schema, engine=engine, join_templates=joins) - self.data_center = DataCenter(schemas={"testdb": self.database}) + self.data_center = DataCenter(schemas={"testdb": self.database}, user_path="") def tearDown(self) -> None: self.db_file.close()