
Reworked generic loaders, and started working on cache management.
Suchismit4 committed Dec 28, 2024
1 parent db996f8 commit 9ef1744
Showing 16 changed files with 380 additions and 288 deletions.
3 changes: 2 additions & 1 deletion example_compustata.py
@@ -25,13 +25,14 @@ def main():
'data_path': 'wrds/equity/compustat',
'config': {
'columns_to_read': columns_to_read,
'freq': 'A',
'num_processes': 16,
'filters': {
'indfmt': 'INDL',
'datafmt': 'STD',
'popsrc': 'D',
'consol': 'C',
'datadate': ('>=', '1959-01-01')
'date': ('>=', '1959-01-01')
},
}
}
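The filters block above pairs each column with either a scalar (tested for equality) or an (operator, value) tuple, and this commit also renames the date filter key from 'datadate' to 'date'. Below is a minimal, hypothetical sketch of how such a spec could be applied with pandas; the repository's actual logic is BaseDataSource._apply_filters in the base.py diff further down, so the helper name and operator table here are illustrative assumptions.

# Hypothetical sketch (not the repository's code): applying a filter spec such as
# {'date': ('>=', '1959-01-01')} to a DataFrame. The helper name and operator
# table are assumptions; the real implementation is BaseDataSource._apply_filters.
import operator
import pandas as pd

_OPS = {'>=': operator.ge, '<=': operator.le, '>': operator.gt,
        '<': operator.lt, '==': operator.eq, '!=': operator.ne}

def apply_filters(df: pd.DataFrame, filters: dict) -> pd.DataFrame:
    for column, condition in filters.items():
        if isinstance(condition, (tuple, list)) and len(condition) == 2:
            op, value = condition
            df = df[_OPS[op](df[column], value)]   # e.g. df['date'] >= '1959-01-01'
        else:
            df = df[df[column] == condition]       # scalar value means equality filter
    return df

raw = pd.DataFrame({
    'date': pd.to_datetime(['1955-06-30', '1960-06-30', '1965-06-30']),
    'datafmt': ['STD', 'STD', 'SUMM_STD'],
    'consol': ['C', 'C', 'C'],
})
print(apply_filters(raw, {'datafmt': 'STD', 'date': ('>=', '1959-01-01')}))
# -> only the 1960-06-30 row passes both filters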
17 changes: 9 additions & 8 deletions example_openbb.py
@@ -29,16 +29,17 @@ def main():
"start_date": "2020-01-01",
"end_date": "2021-01-01"
}
},
{
'data_path': 'wrds/equity/crsp',
'config': {
'num_processes': 16,
}
}])
}
# {
# 'data_path': 'wrds/equity/crsp',
# 'config': {
# 'num_processes': 16,
# 'freq': 'D'
# }
# }
])

print(dataset)


if __name__ == "__main__":
main()
Binary file modified src/data/abstracts/__pycache__/base.cpython-312.pyc
168 changes: 119 additions & 49 deletions src/data/abstracts/base.py
@@ -4,7 +4,7 @@
import pandas as pd
import xarray as xr
import json
from typing import Dict, Any, Union
from typing import Dict, Any, Union, Optional
from pathlib import Path
import hashlib
from abc import ABC, abstractmethod
@@ -17,8 +17,8 @@ class BaseDataSource(ABC):
Base class for handling data sources configuration and path management.
This class provides the foundational logic for accessing different data sources
defined in the paths.yaml configuration file. It handles both online and offline
data sources and manages the interaction with the cache system.
defined in the paths.yaml configuration file. It also handles interactions with
the cache system, including saving/loading xarray Datasets as NetCDF.
"""

def __init__(self, data_path: str):
@@ -37,6 +37,7 @@ def _apply_filters(self, df: pd.DataFrame, filters: Dict[str, Any]) -> pd.DataFr
Returns:
pd.DataFrame: The filtered DataFrame.
"""
print(df.columns)
for column, condition in filters.items():
if isinstance(condition, tuple) or isinstance(condition, list) \
and len(condition) == 2:
@@ -63,73 +64,142 @@ def _apply_filters(self, df: pd.DataFrame, filters: Dict[str, Any]) -> pd.DataFr

def get_cache_path(self, **params) -> Path:
"""
Generate a cache file path based on the data path and parameters.
Generate a base 'cache path' (WITHOUT extension) based on self.data_path + hashed params.
We append '.nc' or '.json' (for NetCDF caching) or '.parquet' (for DataFrame caching).
Example:
- data_path='wrds/equity/compustat'
- cache_root='~/data/cache'
=> ~/data/cache/wrds/equity/compustat/{md5hash_of_params}
Args:
**params: Parameters used in the data loading function.
**params: Arbitrary parameters to be hashed into the cache filename.
Returns:
Path: The corresponding cache file path.
Path: The fully-qualified path (no file extension).
"""
# Serialize the parameters to a JSON-formatted string

# Serialize the parameters
params_string = json.dumps(params, sort_keys=True)
# Generate a hash of the parameters string

# Create an MD5 hash of the parameters
params_hash = hashlib.md5(params_string.encode('utf-8')).hexdigest()
# Create the cache path using the data path and hash
cache_path = os.path.join(
self.cache_root,
self.data_path.strip('/'), # Remove leading/trailing slashes
params_hash
)
return Path(f"{cache_path}.parquet")

# Build the subdirectory path, e.g. ~/data/cache/wrds/equity/compustat
sub_dir = self.data_path.strip('/') # remove leading/trailing slashes

cache_dir = os.path.join(self.cache_root, sub_dir)
os.makedirs(cache_dir, exist_ok=True)

# Return something like ~/data/cache/wrds/equity/compustat/<hash>
base_path = os.path.join(cache_dir, params_hash)
return Path(base_path)

def check_cache_netcdf(self, base_path: Path) -> bool:
"""
Check if the NetCDF (.nc) file exists at this path.
"""
netcdf_path = base_path.with_suffix('.nc')
return netcdf_path.exists()

def _metadata_matches(self, metadata_path: Path, request_params: Dict[str, Any]) -> bool:
"""
Compare an existing JSON metadata file with the requested params.
If they match exactly, return True; otherwise False.
def check_cache(self, cache_path: Path) -> bool:
"""Check if valid cache exists for the given path."""
return cache_path.exists()
Args:
metadata_path (Path): Path to the .json file containing cached metadata.
request_params (Dict[str, Any]): The parameters requested for the current load.
def load_from_cache(self, cache_path: Path, frequency: FrequencyType = FrequencyType.DAILY) -> Union[xr.Dataset, None]:
Returns:
bool: True if the metadata file exists and matches request_params, False otherwise.
"""
if not metadata_path.exists():
return False
try:
with open(metadata_path, 'r') as f:
cached_params = json.load(f)
# Direct comparison
return cached_params == request_params
except Exception as e:
print(f"Error reading metadata file {metadata_path}: {e}")
return False

def load_from_cache(
self,
base_path: Path,
request_params: Dict[str, Any],
frequency: FrequencyType = FrequencyType.DAILY
) -> Optional[xr.Dataset]:
"""
Load data from cache file.
Load an xarray.Dataset from a NetCDF (.nc) file, if it exists and matches params.
We also check a .json file with the same base path to ensure the parameters match.
Args:
cache_path: Path to the cache file.
base_path (Path): The base path (no extension). We'll look for <base>.nc and <base>.json.
request_params (Dict[str, Any]): The params used for generating the data. We'll compare
these to what's stored in the .json sidecar.
frequency (FrequencyType): Frequency type to assign to the dataset if needed.
Returns:
xr.Dataset or None: The loaded dataset, or None if loading failed.
xr.Dataset or None: The loaded Dataset if everything matches; None otherwise.
"""
if self.check_cache(cache_path):
try:
# Load the parquet file into a pandas DataFrame
df = pd.read_parquet(cache_path)
data = self._convert_to_xarray(df,
list(df.columns.drop(['date', 'identifier'])),
frequency=frequency)
return data
except Exception as e:
print(f"Failed to load from cache: {e}")
return None
else:
netcdf_path = base_path.with_suffix('.nc')
metadata_path = base_path.with_suffix('.json')

if not netcdf_path.exists():
return None

# Check if metadata matches
if not self._metadata_matches(metadata_path, request_params):
return None

# TODO: Better caching.
def save_to_cache(self, df: pd.DataFrame, cache_path: Path, params: dict):
try:
ds = xr.load_dataset(netcdf_path) # or xr.open_dataset, either is fine
# Optionally you can confirm ds.attrs or other checks, but we'll skip that here.
return ds
except Exception as e:
print(f"Failed to load from NetCDF: {e}")
return None

def save_to_cache(
self,
ds: xr.Dataset,
base_path: Path,
params: Dict[str, Any]
) -> None:
"""
Save data to cache file.
Save an xarray.Dataset to a NetCDF (.nc) file, plus store parameters in a .json sidecar.
Args:
df: The DataFrame to save.
cache_path: Path to the cache file.
params: Parameters used in data loading, saved as metadata.
ds (xr.Dataset): The dataset to save.
base_path (Path): The base path (no extension).
params (dict): Parameters used to generate this dataset, stored in the JSON.
"""
# Save the DataFrame to parquet
cache_path.parent.mkdir(parents=True, exist_ok=True)
df.to_parquet(cache_path)

# Save metadata
metadata_path = cache_path.with_suffix('.json')
with open(metadata_path, 'w') as f:
json.dump(params, f)

netcdf_path = base_path.with_suffix('.nc')
metadata_path = base_path.with_suffix('.json')
netcdf_path.parent.mkdir(parents=True, exist_ok=True)

try:
# Add params to ds.attrs
ds.attrs.update(params)

ds.to_netcdf(
path=netcdf_path,
mode='w',
format='NETCDF4',
engine='netcdf4'
)

# Save the same params into a JSON sidecar
with open(metadata_path, 'w') as f:
json.dump(params, f, indent=2)
except Exception as e:
print(f"Failed to save Dataset to NetCDF cache: {e}")

def _convert_to_xarray(self, df: pd.DataFrame, columns, frequency: FrequencyType = FrequencyType.DAILY) -> xr.Dataset:
"""
Convert pandas DataFrame to xarray Dataset.
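Taken together, the new helpers implement a simple contract: get_cache_path hashes the request params into an extension-less base path, save_to_cache writes <hash>.nc plus a <hash>.json sidecar holding those params, and load_from_cache only returns the dataset when the sidecar matches the current request. Below is a self-contained sketch of that round trip using plain xarray and json; it mirrors, but is not, the class code above, and the toy dataset, params, and temporary cache root are assumptions.

# Self-contained sketch of the NetCDF + JSON-sidecar caching scheme above.
# Mirrors get_cache_path / save_to_cache / load_from_cache; the toy dataset,
# params, and temporary cache root are illustrative assumptions.
import hashlib
import json
import tempfile
from pathlib import Path

import numpy as np
import xarray as xr

params = {'freq': 'D', 'num_processes': 16}
params_hash = hashlib.md5(json.dumps(params, sort_keys=True).encode('utf-8')).hexdigest()

cache_dir = Path(tempfile.mkdtemp()) / 'wrds' / 'equity' / 'compustat'
cache_dir.mkdir(parents=True, exist_ok=True)
base_path = cache_dir / params_hash                        # no extension, as in get_cache_path

# Save: write the data as NetCDF and the request params as a JSON sidecar
ds = xr.Dataset({'ret': ('time', np.random.rand(5))})
ds.attrs.update(params)                                    # params also stored in ds.attrs
ds.to_netcdf(base_path.with_suffix('.nc'), mode='w', format='NETCDF4')
base_path.with_suffix('.json').write_text(json.dumps(params, indent=2))

# Load: only reuse the .nc file when the sidecar matches the requested params
if (base_path.with_suffix('.nc').exists()
        and json.loads(base_path.with_suffix('.json').read_text()) == params):
    cached = xr.load_dataset(base_path.with_suffix('.nc'))
    print(cached)                                          # cache hit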
Binary file modified src/data/core/__pycache__/struct.cpython-312.pyc
56 changes: 28 additions & 28 deletions src/data/core/struct.py
@@ -1,4 +1,4 @@
# src/data/data.py
# src/data/core/struct.py


import numpy as np
@@ -196,8 +196,8 @@ def from_table(
pre_dup_mask = data.duplicated(subset=dup_subset, keep=False)
pre_dup_data = data[pre_dup_mask].sort_values(dup_subset)

# print("==== DUPLICATE ROWS BEFORE SETTING INDEX ====") DEBUG
# print(pre_dup_data)
print("==== DUPLICATE ROWS BEFORE SETTING INDEX ====") #DEBUG
print(pre_dup_data)

counts = (
data
@@ -207,36 +207,36 @@
.sort_values(ascending=False)
)

# # Show only those with duplicates (DEBUG)
# counts_dup = counts[counts > 1]
# print("==== MULTIINDEX GROUPS THAT APPEAR MORE THAN ONCE ====")
# print(counts_dup)
# Show only those with duplicates (DEBUG)
counts_dup = counts[counts > 1]
print("==== MULTIINDEX GROUPS THAT APPEAR MORE THAN ONCE ====")
print(counts_dup)

# Set DataFrame index and reindex to include all possible combinations
data.set_index(index_names, inplace=True)


# ## OVERLOOK (DEBUG)
# print("The multi-index is not unique. Identifying duplicate index entries:")
# # Find duplicated index entries
# # duplicated_indices = data.index[data.index.duplicated(keep=False)]
# # print(duplicated_indices.unique())
# # print(data[data.index.isin(duplicated_indices)])
# duplicated_mask = data.index.duplicated(keep=False)
# duplicated_data = data[duplicated_mask]
# print("==== DUPLICATE ROWS ====")
# print(duplicated_data)
# i = 0
# for idx, group in duplicated_data.groupby(level=[0, 1, 2, 3]):
# print("MultiIndex:", idx)
# print(group)
# print("-----")
# i += 1
# if (i > 3):
# break

# quit(0)
# # OVERLOOK STOP
# OVERLOOK (DEBUG)
print("The multi-index is not unique. Identifying duplicate index entries:")
# Find duplicated index entries
duplicated_indices = data.index[data.index.duplicated(keep=False)]
print(duplicated_indices.unique())
print(data[data.index.isin(duplicated_indices)])
duplicated_mask = data.index.duplicated(keep=False)
duplicated_data = data[duplicated_mask]
print("==== DUPLICATE ROWS ====")
print(duplicated_data)
i = 0
for idx, group in duplicated_data.groupby(level=[0, 1, 2, 3]):
print("MultiIndex:", idx)
print(group)
print("-----")
i += 1
if (i > 3):
break

quit(0)
# OVERLOOK STOP

data = data.reindex(full_index)

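The diagnostics enabled above come down to two pandas idioms: flag rows whose prospective index key repeats (duplicated(..., keep=False)), and count group sizes over the index levels to see which keys collide before the MultiIndex is set. A toy, self-contained version of that pattern follows; the column names and the drop_duplicates resolution at the end are assumptions for illustration, not something this commit does.

# Toy sketch of the duplicate diagnostics used in from_table above.
# Column names ('date', 'identifier') and the final dedup step are assumptions.
import pandas as pd

data = pd.DataFrame({
    'date': ['2024-01-02', '2024-01-02', '2024-01-03'],
    'identifier': ['AAPL', 'AAPL', 'AAPL'],
    'ret': [0.011, 0.011, -0.020],
})
dup_subset = ['date', 'identifier']

# Rows whose (date, identifier) key appears more than once
dup_mask = data.duplicated(subset=dup_subset, keep=False)
print(data[dup_mask].sort_values(dup_subset))

# Group sizes, largest first -- any count > 1 would make the MultiIndex non-unique
counts = data.groupby(dup_subset).size().sort_values(ascending=False)
print(counts[counts > 1])

# One possible resolution (not what the commit does): keep the last duplicate
deduped = data.drop_duplicates(subset=dup_subset, keep='last')
print(deduped)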
Binary file modified src/data/loaders/open_bb/__pycache__/generic.cpython-312.pyc
5 changes: 3 additions & 2 deletions src/data/loaders/open_bb/generic.py
@@ -40,8 +40,9 @@ def load_data(self, **config) -> xr.Dataset:
cache_path = self.get_cache_path(**cache_params)

# Try loading from cache
cached_ds = self.load_from_cache(cache_path, frequency=FrequencyType.DAILY)
cached_ds = self.load_from_cache(cache_path, request_params=cache_params)
if cached_ds is not None:
print("Loaded from NetCDF cache")
return cached_ds

# Resolve the function in OpenBB
@@ -81,6 +82,6 @@ def load_data(self, **config) -> xr.Dataset:
frequency=FrequencyType.DAILY,
)

# self.save_to_cache(ds, cache_path, frequency=FrequencyType.DAILY)
self.save_to_cache(ds, cache_path, cache_params)

return ds
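
One consequence of passing cache_params into both get_cache_path and load_from_cache is that the cache key is an MD5 of the JSON-serialized config, so any change to the request (a different date range, frequency, etc.) maps to a different .nc/.json pair. A small illustration, with hypothetical parameter values:

# Illustration: the cache key is md5(json.dumps(params, sort_keys=True)),
# so different request params never collide on the same cached NetCDF file.
# The parameter values below are hypothetical.
import hashlib
import json

def cache_key(params: dict) -> str:
    return hashlib.md5(json.dumps(params, sort_keys=True).encode('utf-8')).hexdigest()

a = cache_key({'start_date': '2020-01-01', 'end_date': '2021-01-01'})
b = cache_key({'start_date': '2020-01-01', 'end_date': '2021-06-01'})
print(a != b)  # True: a new end_date produces a new cache entry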
Binary file modified src/data/loaders/wrds/__pycache__/__init__.cpython-312.pyc
Binary file modified src/data/loaders/wrds/__pycache__/compustat.cpython-312.pyc
Binary file modified src/data/loaders/wrds/__pycache__/crsp.cpython-312.pyc
