
Commit

Changes to store interfaces and implementations (#230)
carlosgjs authored Oct 5, 2023
1 parent e80326c commit 88dc14e
Showing 16 changed files with 526 additions and 336 deletions.
55 changes: 41 additions & 14 deletions src/noisepy/seis/asdfstore.py
@@ -2,7 +2,7 @@
 import logging
 import os
 from pathlib import Path
-from typing import Callable, Dict, Generic, List, Optional, Tuple, TypeVar
+from typing import Callable, Dict, Generic, List, Optional, Set, Tuple, TypeVar
 
 import numpy as np
 import obspy
@@ -54,16 +54,20 @@ def __init__(
         self.parse_filename = parse_filename
 
     def __getitem__(self, key: T) -> pyasdf.ASDFDataSet:
+        return self._get_dataset(key, self.mode)
+
+    def _get_dataset(self, key: T, mode: str) -> pyasdf.ASDFDataSet:
         file_name = self.get_filename(key)
         file_path = os.path.join(self.directory, file_name)
-        return _get_dataset(file_path, self.mode)
+        return _get_dataset(file_path, mode)
 
     def get_keys(self) -> List[T]:
         h5files = sorted(glob.glob(os.path.join(self.directory, "**/*.h5"), recursive=True))
         return list(map(self.parse_filename, h5files))
 
     def contains(self, key: T, data_type: str, path: str = None):
-        ccf_ds = self[key]
+        # contains is always a read
+        ccf_ds = self._get_dataset(key, "r")
 
         if not ccf_ds:
             return False
@@ -128,7 +132,7 @@ def __init__(self, directory: str, mode: str = "a") -> None:
         self.datasets = ASDFDirectory(directory, mode, _filename_from_timespan, parse_timespan)
 
     # CrossCorrelationDataStore implementation
-    def contains(self, timespan: DateTimeRange, src: Station, rec: Station) -> bool:
+    def contains(self, src: Station, rec: Station, timespan: DateTimeRange) -> bool:
         station_pair = self._get_station_pair(src, rec)
         contains = self.datasets.contains(timespan, station_pair)
         if contains:
@@ -149,19 +153,23 @@ def append(
         channels = self._get_channel_pair(cc.src, cc.rec)
         self.datasets.add_aux_data(timespan, cc.parameters, station_pair, channels, cc.data)
 
-    def get_timespans(self) -> List[DateTimeRange]:
-        return self.datasets.get_keys()
+    def get_timespans(self, src: Station, rec: Station) -> List[DateTimeRange]:
+        timespans = {}
+        pair_key = self._get_station_pair(src, rec)
+
+        def visit(pairs, ts):
+            if pair_key in pairs:
+                timespans[str(ts)] = ts
+
+        self._visit_pairs(visit)
+        return sorted(timespans.values(), key=lambda t: str(t))
 
     def get_station_pairs(self) -> List[Tuple[Station, Station]]:
-        timespans = self.get_timespans()
         pairs_all = set()
-        for timespan in timespans:
-            with self.datasets[timespan] as ccf_ds:
-                data = ccf_ds.auxiliary_data.list()
-                pairs_all.update(parse_station_pair(p) for p in data if p != PROGRESS_DATATYPE)
+        self._visit_pairs(lambda pairs, _: pairs_all.update((parse_station_pair(p) for p in pairs)))
         return list(pairs_all)
 
-    def read_correlations(self, timespan: DateTimeRange, src_sta: Station, rec_sta: Station) -> List[CrossCorrelation]:
+    def read(self, timespan: DateTimeRange, src_sta: Station, rec_sta: Station) -> List[CrossCorrelation]:
         with self.datasets[timespan] as ccf_ds:
             dtype = self._get_station_pair(src_sta, rec_sta)
             if dtype not in ccf_ds.auxiliary_data:
@@ -175,20 +183,39 @@ def read_correlations(self, timespan: DateTimeRange, src_sta: Station, rec_sta:
                 ccs.append(CrossCorrelation(src_ch, rec_ch, stream.parameters, stream.data[:]))
         return ccs
 
+    def _visit_pairs(self, visitor: Callable[[Set[Tuple[str, str]], DateTimeRange], None]):
+        all_timespans = self.datasets.get_keys()
+        for timespan in all_timespans:
+            with self.datasets[timespan] as ccf_ds:
+                data = ccf_ds.auxiliary_data.list()
+                pairs = {p for p in data if p != PROGRESS_DATATYPE}
+                visitor(pairs, timespan)
+
     def _get_channel_pair(self, src_chan: ChannelType, rec_chan: ChannelType) -> str:
         return f"{src_chan}_{rec_chan}"
 
     def _get_station_pair(self, src_sta: Station, rec_sta: Station) -> str:
         return f"{src_sta}_{rec_sta}"


 class ASDFStackStore(StackStore):
     def __init__(self, directory: str, mode: str = "a"):
         super().__init__()
         self.datasets = ASDFDirectory(directory, mode, _filename_from_stations, _parse_station_pair_h5file)
 
-    def append(self, src: Station, rec: Station, stacks: List[Stack]):
+    # TODO: Do we want to support storing stacks from different timespans in the same store?
+    def append(self, timespan: DateTimeRange, src: Station, rec: Station, stacks: List[Stack]):
         for stack in stacks:
             self.datasets.add_aux_data((src, rec), stack.parameters, stack.name, stack.component, stack.data)
 
     def get_station_pairs(self) -> List[Tuple[Station, Station]]:
         return self.datasets.get_keys()
 
-    def read_stacks(self, src: Station, rec: Station) -> List[Stack]:
+    def get_timespans(self, src: Station, rec: Station) -> List[DateTimeRange]:
+        # TODO: Do we want to support storing stacks from different timespans in the same store?
+        return []
+
+    def read(self, timespan: DateTimeRange, src: Station, rec: Station) -> List[Stack]:
         stacks = []
         with self.datasets[(src, rec)] as ds:
             for name in ds.auxiliary_data.list():
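
Taken together, the asdfstore changes settle on a (src, rec, timespan) argument order and shorter read/append names. A usage sketch against the updated interface, assuming a hypothetical local ./ccf_data directory of cross-correlation HDF5 files:

    from noisepy.seis.asdfstore import ASDFCCStore

    store = ASDFCCStore("./ccf_data", mode="r")

    for src, rec in store.get_station_pairs():
        # get_timespans is now scoped to a single station pair
        for ts in store.get_timespans(src, rec):
            if store.contains(src, rec, ts):  # new argument order
                ccs = store.read(ts, src, rec)  # renamed from read_correlations
                print(f"{src}_{rec} @ {ts}: {len(ccs)} channel pairs")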
10 changes: 5 additions & 5 deletions src/noisepy/seis/correlate.py
@@ -112,7 +112,7 @@ def cc_timespan(
     pair_filter: Callable[[Channel, Channel], bool] = lambda src, rec: True,
 ) -> bool:
     errors = False
-    executor = ThreadPoolExecutor(max_workers=12)
+    executor = ThreadPoolExecutor()
     tlog = TimeLogger(logger, logging.INFO, prefix="CC Main")
     """
     LOADING NOISE DATA AND DO FFT
@@ -134,7 +134,7 @@
     station_pairs = list(create_pairs(pair_filter, all_channels, fft_params.acorr_only).keys())
     # Check for stations that are already done, do this in parallel
     logger.info(f"Checking for stations already done: {len(station_pairs)} pairs")
-    station_pair_dones = list(executor.map(lambda p: cc_store.contains(ts, p[0], p[1]), station_pairs))
+    station_pair_dones = list(executor.map(lambda p: cc_store.contains(p[0], p[1], ts), station_pairs))
 
     missing_pairs = [pair for pair, done in zip(station_pairs, station_pair_dones) if not done]
     # get a set of unique stations from the list of pairs
@@ -260,10 +260,10 @@ def create_pairs(
                 if src_chan.station != rec_chan.station:
                     continue
             if ffts and iiS not in ffts:
-                logger.warning(f"No FFT data available for channel '{src_chan}', skipped")
+                logger.warning(f"No FFT data available for src channel '{src_chan}', skipped")
                 continue
             if ffts and iiR not in ffts:
-                logger.warning(f"No FFT data available for channel '{rec_chan}', skipped")
+                logger.warning(f"No FFT data available for rec channel '{rec_chan}', skipped")
                 continue
 
             station_pairs[(src_chan.station, rec_chan.station)].append((iiS, iiR))
@@ -285,7 +285,7 @@ def stations_cross_correlation(
     tlog = TimeLogger(logger, logging.DEBUG)
     datas = []
     try:
-        if cc_store.contains(ts, src, rec):
+        if cc_store.contains(src, rec, ts):
             logger.info(f"Skipping {src}_{rec} for {ts} since it's already done")
             return True, None
 
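
The executor change drops the hard-coded max_workers=12 in favor of Python's default sizing, which is min(32, os.cpu_count() + 4) on Python 3.8+. A self-contained sketch of the parallel "already done" check pattern used above, with a hypothetical contains stand-in in place of cc_store.contains(src, rec, ts):

    from concurrent.futures import ThreadPoolExecutor

    executor = ThreadPoolExecutor()  # worker count derived from CPU count

    def contains(pair):
        # stand-in predicate: pretend pairs with even hashes are already done
        return hash(pair) % 2 == 0

    station_pairs = [("CI.ARV", "CI.BAK"), ("CI.ARV", "CI.CCC"), ("CI.BAK", "CI.CCC")]
    dones = list(executor.map(contains, station_pairs))
    missing = [p for p, done in zip(station_pairs, dones) if not done]
    print(missing)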
20 changes: 18 additions & 2 deletions src/noisepy/seis/datatypes.py
@@ -16,7 +16,7 @@
 from pydantic.functional_validators import model_validator
 from pydantic_yaml import parse_yaml_raw_as, to_yaml_str
 
-from noisepy.seis.utils import get_filesystem
+from noisepy.seis.utils import get_filesystem, remove_nan_rows, remove_nans
 
 INVALID_COORD = -sys.float_info.max
@@ -82,6 +82,13 @@ def __init__(
         self.elevation = elevation
         self.location = location
 
+    def parse(sta: str) -> Optional[Station]:
+        # Parse from: CI.ARV
+        parts = sta.split(".")
+        if len(parts) != 2:
+            return None
+        return Station(parts[0], parts[1])
+
     def valid(self) -> bool:
         return min(self.lat, self.lon, self.elevation) > INVALID_COORD
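
The new Station.parse helper accepts a single NET.STA code and returns None for anything else; a quick sketch of its behavior (example codes assumed):

    from noisepy.seis.datatypes import Station

    sta = Station.parse("CI.ARV")         # Station with network "CI", name "ARV"
    bad = Station.parse("CI.ARV_CI.BAK")  # None: "." split yields 3 parts, not 2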

Expand Down Expand Up @@ -297,7 +304,7 @@ def __init__(self, params: Dict[str, Any], data: np.ndarray):
def get_metadata(self) -> Tuple:
pass

def pack(datas: List[AnnotatedData]) -> AnnotatedData:
def pack(datas: List[AnnotatedData]) -> Tuple[np.ndarray, List[Dict[str, Any]]]:
if len(datas) == 0:
raise ValueError("Cannot pack empty list of data")
# Some arrays may have different lengths, so pad them with NaNs for stacking
@@ -343,6 +350,12 @@ def __init__(self, src: ChannelType, rec: ChannelType, params: Dict[str, Any], d
     def get_metadata(self) -> Tuple:
         return (self.src.name, self.src.location, self.rec.name, self.rec.location, self.parameters)
 
+    def load_instances(tuples: List[Tuple[np.ndarray, Dict[str, Any]]]) -> List[CrossCorrelation]:
+        return [
+            CrossCorrelation(ChannelType(src, src_loc), ChannelType(rec, rec_loc), params, remove_nan_rows(data))
+            for data, (src, src_loc, rec, rec_loc, params) in tuples
+        ]
+
 
 class Stack(AnnotatedData):
     component: str  # e.g. "EE", "EN", ...
@@ -356,6 +369,9 @@ def __init__(self, component: str, name: str, params: Dict[str, Any], data: np.n
     def get_metadata(self) -> Tuple:
         return (self.component, self.name, self.parameters)
 
+    def load_instances(tuples: List[Tuple[np.ndarray, Dict[str, Any]]]) -> List[Stack]:
+        return [Stack(comp, name, params, remove_nans(data)) for data, (comp, name, params) in tuples]
+
 
 def to_json_types(params: Dict[str, Any]) -> Dict[str, Any]:
     return {k: _to_json_type(v) for k, v in params.items()}
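
The pack/load_instances pair is symmetric: pack pads variable-length arrays with NaNs so they stack into one ndarray, and load_instances strips the padding back off via remove_nans/remove_nan_rows. A sketch of the load side using hand-built rows (values and parameters are made up; the metadata tuples follow the (component, name, parameters) shape that Stack.get_metadata produces):

    import numpy as np
    from noisepy.seis.datatypes import Stack

    # NaN-padded rows as they might come back from storage
    rows = np.array([[1.0, 2.0, 3.0, np.nan],
                     [4.0, 5.0, np.nan, np.nan]])
    meta = [("ZZ", "Allstack_linear", {"dt": 0.05}),
            ("ZZ", "Allstack_pws", {"dt": 0.05})]

    stacks = Stack.load_instances(list(zip(rows, meta)))
    # remove_nans drops the padding, restoring lengths 3 and 2
    print([s.data.shape for s in stacks])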
