From 187b4c0615a03920c2659190b7a818d057ba12ce Mon Sep 17 00:00:00 2001 From: zhuwq Date: Wed, 6 Nov 2024 22:15:52 -0800 Subject: [PATCH] add subdir_level --- phasenet/predict.py | 8 +++++--- 1 file changed, 5 insertions(+), 3 deletions(-) diff --git a/phasenet/predict.py b/phasenet/predict.py index cf61d69..7365d25 100755 --- a/phasenet/predict.py +++ b/phasenet/predict.py @@ -13,6 +13,7 @@ import pandas as pd import tensorflow as tf from data_reader import DataReader_mseed_array, DataReader_pred +from model import ModelConfig, UNet from postprocess import ( extract_amplitude, extract_picks, @@ -23,8 +24,6 @@ from tqdm import tqdm from visulization import plot_waveform -from model import ModelConfig, UNet - tf.compat.v1.disable_eager_execution() tf.compat.v1.logging.set_verbosity(tf.compat.v1.logging.ERROR) @@ -60,6 +59,8 @@ def read_args(): parser.add_argument("--highpass_filter", default=0.0, type=float, help="Highpass filter") parser.add_argument("--response_xml", default=None, type=str, help="response xml file") parser.add_argument("--sampling_rate", default=100, type=float, help="sampling rate") + + parser.add_argument("--subdir_level", default=2, type=int, help="subdirectory level") args = parser.parse_args() return args @@ -158,7 +159,8 @@ def pred_fn(args, data_reader, figure_dir=None, prob_dir=None, log_dir=None): if len(fname_batch) == 1: # ### FIX: Hard code for NCEDC and SCEDC tmp = fname_batch[0].decode().split(",")[0].split("/") - subdir = "/".join(tmp[-1-3:-1]) + # subdir = "/".join(tmp[-1 - 3 : -1]) + subdir = "/".join(tmp[-1 - args.subdir_level : -1]) fname = tmp[-1].rstrip("\n").rstrip(".mseed").rstrip(".ms") + ".csv" # csv_name = f"quakeflow_catalog/NC/phasenet/{subdir}/{fname}" # csv_name = f"quakeflow_catalog/SC/phasenet/{subdir}/{fname}"