Skip to content

Commit

Permalink
option to load test data from csv file
Browse files Browse the repository at this point in the history
  • Loading branch information
tnixon committed Nov 5, 2024
1 parent 2f00ced commit 4f85463
Showing 1 changed file with 21 additions and 7 deletions.
28 changes: 21 additions & 7 deletions python/tests/base.py
Original file line number Diff line number Diff line change
Expand Up @@ -5,11 +5,15 @@
from functools import cached_property

import jsonref
import pandas as pd

import pyspark.sql.functions as sfn
from chispa import assert_df_equality
from delta.pip_utils import configure_spark_with_delta_pip
from pyspark.sql import SparkSession
from pyspark.sql.dataframe import DataFrame

from chispa import assert_df_equality
from delta.pip_utils import configure_spark_with_delta_pip

from tempo.intervals import IntervalsDF
from tempo.tsdf import TSDF

Expand Down Expand Up @@ -43,11 +47,20 @@ def df_schema(self) -> str:
"""
return self.df["schema"]

def df_data(self) -> list:
def df_data(self) -> Union[list, pd.DataFrame]:
"""
:return: the data component of the test data
"""
return self.df["data"]
data = self.df["data"]
# return data literals (list of rows)
if isinstance(data, list):
return data
# load data from a csv file
elif isinstance(data, str):
csv_path = SparkTest.getTestDataFilePath(data, extension='')
return pd.read_csv(csv_path)
else:
raise ValueError(f"Invalid data type {type(data)}")

# TSDF metadata

Expand Down Expand Up @@ -234,7 +247,8 @@ def get_data_as_idf(self, name: str, convert_ts_col=True):

TEST_DATA_FOLDER = "unit_test_data"

def __getTestDataFilePath(self, test_file_name: str) -> str:
@classmethod
def getTestDataFilePath(cls, test_file_name: str, extension: str = '.json') -> str:
# what folder are we running from?
cwd = os.path.basename(os.getcwd())

Expand All @@ -251,7 +265,7 @@ def __getTestDataFilePath(self, test_file_name: str) -> str:
)

# return appropriate path
return f"{dir_path}/{self.TEST_DATA_FOLDER}/{test_file_name}.json"
return f"{dir_path}/{cls.TEST_DATA_FOLDER}/{test_file_name}{extension}"

def __loadTestData(self, test_case_path: str) -> dict:
"""
Expand All @@ -265,7 +279,7 @@ def __loadTestData(self, test_case_path: str) -> dict:
# load the test data file if it hasn't been loaded yet
if self.test_data_file is None:
# find our test data file
test_data_filename = self.__getTestDataFilePath(file_name)
test_data_filename = self.getTestDataFilePath(file_name)
if not os.path.isfile(test_data_filename):
warnings.warn(f"Could not load test data file {test_data_filename}")
self.test_data_file = {}
Expand Down

0 comments on commit 4f85463

Please sign in to comment.