Skip to content

Commit

Permalink
add new CLI args
Browse files Browse the repository at this point in the history
  • Loading branch information
jrob93 committed Nov 7, 2024
1 parent e980c08 commit fd5bb4a
Show file tree
Hide file tree
Showing 13 changed files with 902 additions and 254 deletions.
72 changes: 52 additions & 20 deletions src/adler/adler_run.py
Original file line number Diff line number Diff line change
@@ -1,6 +1,7 @@
import logging
import argparse
import astropy.units as u
from astropy.time import Time
import numpy as np
import pandas as pd
import matplotlib.pyplot as plt
Expand Down Expand Up @@ -103,7 +104,8 @@ def runAdler(cli_args):

# initial simple phase curve filter model with fixed G12
pc = PhaseCurve(
H=sso.H * u.mag,
# H=sso.H * u.mag,
H=sso.H,
phase_parameter_1=0.62,
model_name="HG12_Pen16",
)
Expand All @@ -118,14 +120,24 @@ def runAdler(cli_args):

# do a HG12_Pen16 fit to the past data
pc_fit = pc.FitModel(
np.array(df_obs["phaseAngle"]) * u.deg,
np.array(df_obs["reduced_mag"]) * u.mag,
np.array(df_obs["magErr"]) * u.mag,
# np.array(df_obs["phaseAngle"]) * u.deg,
# np.array(df_obs["reduced_mag"]) * u.mag,
# np.array(df_obs["magErr"]) * u.mag,
np.radians(np.array(df_obs["phaseAngle"])),
np.array(df_obs["reduced_mag"]),
np.array(df_obs["magErr"]),
)
pc_fit = pc.InitModelSbpy(pc_fit)

# Store the fitted values in an AdlerData object
adler_data.populate_phase_parameters(filt, **pc_fit.__dict__)
# Store the fitted values, and metadata, in an AdlerData object
ad_params = pc_fit.__dict__
ad_params["phaseAngle_min"] = np.amin(df_obs["phaseAngle"]) # * u.deg
ad_params["phaseAngle_range"] = np.ptp(df_obs["phaseAngle"]) # * u.deg
ad_params["arc"] = np.ptp(df_obs["midPointMjdTai"]) # * u.d
ad_params["nobs"] = len(df_obs)
ad_params["modelFitMjd"] = Time.now().mjd
# adler_data.populate_phase_parameters(filt, **pc_fit.__dict__)
adler_data.populate_phase_parameters(filt, **ad_params)

# add to plot
ax1 = fig.axes[0]
Expand All @@ -134,13 +146,18 @@ def runAdler(cli_args):
alpha = np.linspace(0, np.amax(obs.phaseAngle)) * u.deg
ax1.plot(
alpha.value,
pc_fit.ReducedMag(alpha).value,
label="{}, H={:.2f}, G12={:.2f}".format(filt, pc_fit.H.value, pc_fit.phase_parameter_1),
# pc_fit.ReducedMag(alpha).value,
# label="{}, H={:.2f}, G12={:.2f}".format(filt, pc_fit.H.value, pc_fit.phase_parameter_1),
pc_fit.ReducedMag(alpha),
label="{}, H={:.2f}, G12={:.2f}".format(filt, pc_fit.H, pc_fit.phase_parameter_1),
)

# TODO: save the figures if an outpath is provided
ax1.legend()
if cli_args.outpath:

# TODO: Use a CLI arg flag to open figure interactively instead of saving?
if cli_args.plot_show:
plt.show()
# Save figures at the outpath location
else:
fig_file = "{}/phase_curve_{}_{}.png".format(
cli_args.outpath, cli_args.ssObjectId, int(np.amax(df_obs["midPointMjdTai"]))
)
Expand All @@ -149,11 +166,15 @@ def runAdler(cli_args):
logger.info(msg)
fig = plot_errorbar(planetoid, fig=fig, filename=fig_file) # TODO: add titles with filter name?
plt.close()
else:
plt.show()

# TODO: output adler values to a database
# Output adler values to a database if a db_name is provided
print(adler_data.__dict__)
if cli_args.db_name:
adler_db = "{}/{}".format(cli_args.outpath, cli_args.db_name)
msg = "write to {}".format(adler_db)
print(msg)
logger.info(msg)
adler_data.write_row_to_database(adler_db)

# analyse colours for the filters provided
logger.info("Calculate colours: {}".format(cli_args.colour_list))
Expand Down Expand Up @@ -183,10 +204,10 @@ def runAdler(cli_args):

# determine the filt_obs - filt_ref colour
# generate a plot
if cli_args.outpath:
plot_dir = cli_args.outpath
else:
if cli_args.plot_show:
plot_dir = None
else:
plot_dir = cli_args.outpath

col_dict = col_obs_ref(
planetoid,
Expand All @@ -198,6 +219,7 @@ def runAdler(cli_args):
# x1 = x1,
# x2 = x2,
plot_dir=plot_dir,
plot_show=cli_args.plot_show,
)

print(col_dict)
Expand Down Expand Up @@ -253,17 +275,27 @@ def main():
optional_group.add_argument(
"-n",
"--db_name",
help="Stem filename of output database. If this doesn't exist, it will be created. Default: adler_out.",
# help="Stem filename of output database. If this doesn't exist, it will be created. Default: adler_out.",
# type=str,
# default="adler_out",
help="Optional filename of output database, used to store Adler results in a db if provided.",
type=str,
default="adler_out",
default=None,
)
optional_group.add_argument(
"-i",
"--sql_filename",
help="Optional input path location of a sql database file containing observations",
help="Optional input path location of a sql database file containing observations.",
type=str,
default=None,
)
# TODO: add flag argument to display plots instead of saving them
optional_group.add_argument(
"-p",
"--plot_show",
help="Optional flag to display plots interactively instead of saving to file.",
action="store_true",
)

args = parser.parse_args()

Expand Down
4 changes: 3 additions & 1 deletion src/adler/objectdata/AdlerData.py
Original file line number Diff line number Diff line change
Expand Up @@ -7,14 +7,15 @@
from datetime import datetime, timezone


FILTER_DEPENDENT_KEYS = ["phaseAngle_min", "phaseAngle_range", "nobs", "arc", "modelFitMjd"]
FILTER_DEPENDENT_KEYS = ["phaseAngle_min", "phaseAngle_range", "nobs", "arc"]
MODEL_DEPENDENT_KEYS = [
"H",
"H_err",
"phase_parameter_1",
"phase_parameter_1_err",
"phase_parameter_2",
"phase_parameter_2_err",
"modelFitMjd",
]
ALL_FILTER_LIST = ["u", "g", "r", "i", "z", "y"]

Expand Down Expand Up @@ -522,6 +523,7 @@ class PhaseModelDependentAdler:
phase_parameter_1_err: float = np.nan
phase_parameter_2: float = np.nan
phase_parameter_2_err: float = np.nan
modelFitMjd: float = np.nan


class PhaseParameterOutput:
Expand Down
37 changes: 22 additions & 15 deletions src/adler/science/Colour.py
Original file line number Diff line number Diff line change
Expand Up @@ -19,6 +19,7 @@ def col_obs_ref(
x1=None,
x2=None,
plot_dir=None,
plot_show=False,
):
"""A function to calculate the colour of an Adler planetoid object.
An observation in a given filter (filt_obs) is compared to previous observation in a reference filter (filt_ref).
Expand Down Expand Up @@ -99,7 +100,8 @@ def col_obs_ref(
x_obs = df_obs.iloc[i][x_col]
y_obs = df_obs.iloc[i][y_col]
yerr_obs = df_obs.iloc[i][yerr_col]
obsId = df_obs.iloc[i][obsId_col]
# obsId = df_obs.iloc[i][obsId_col] # NB: for some reason iloc here doesn't preserve the int64 dtype of obsId_col...?
obsId = df_obs[obsId_col].iloc[i]

# select observations in the reference filter from before the new obs
ref_mask = df_obs_ref[x_col] < x_obs
Expand All @@ -112,11 +114,11 @@ def col_obs_ref(

# select only the N_ref ref obs for comparison
_df_obs_ref = df_obs_ref[ref_mask].iloc[-_N_ref:]
print(len(_df_obs_ref))
print(np.array(_df_obs_ref[x_col]))
# print(len(_df_obs_ref))
# print(np.array(_df_obs_ref[x_col]))
if len(_df_obs_ref) == 0:
print("no reference observations") # TODO: add proper error handling and logging here
return df_obs
# print("insufficient reference observations") # TODO: add proper error handling and logging here
return

# determine reference observation values
y_ref = np.mean(_df_obs_ref[y_col]) # TODO: add option to choose statistic, e.g. mean or median?
Expand All @@ -127,7 +129,9 @@ def col_obs_ref(

# Create the colour dict
col_dict = {}
col_dict[obsId_col] = obsId
col_dict[obsId_col] = np.int64(
obsId
) # store id as an int to make sure it doesn't get stored as float e notation!
col_dict[x_col] = x_obs
col_dict[colour] = y_obs - y_ref
col_dict[delta_t_col] = x_obs - x2_ref
Expand All @@ -141,7 +145,7 @@ def col_obs_ref(
# need to test error case where there are no r filter obs yet

# TODO: add a plotting option?
if plot_dir:
if plot_dir or plot_show:
fig = plt.figure()
gs = gridspec.GridSpec(1, 1)
ax1 = plt.subplot(gs[0, 0])
Expand All @@ -160,13 +164,16 @@ def col_obs_ref(
ax1.legend()
ax1.invert_yaxis()

fname = "{}/colour_plot_{}_{}-{}_{}.png".format(
plot_dir, planetoid.ssObjectId, filt_obs, filt_ref, int(x_obs)
)
print("Save figure: {}".format(fname))
plt.savefig(fname, facecolor="w", transparent=True, bbox_inches="tight")

# plt.show() # TODO: add option to display figure, or to return the fig object?
plt.close()
if plot_dir:
fname = "{}/colour_plot_{}_{}-{}_{}.png".format(
plot_dir, planetoid.ssObjectId, filt_obs, filt_ref, int(x_obs)
)
print("Save figure: {}".format(fname))
plt.savefig(fname, facecolor="w", transparent=True, bbox_inches="tight")

if plot_show:
plt.show() # TODO: add option to display figure, or to return the fig object?
else:
plt.close()

return col_dict
18 changes: 18 additions & 0 deletions src/adler/utilities/AdlerCLIArguments.py
Original file line number Diff line number Diff line change
Expand Up @@ -25,6 +25,8 @@ def __init__(self, args):
self.outpath = args.outpath
self.db_name = args.db_name
self.sql_filename = args.sql_filename
self.plot_show = args.plot_show
self.phase_model = args.phase_model

self.validate_arguments()

Expand All @@ -47,6 +49,9 @@ def validate_arguments(self):
if self.colour_list:
self._validate_colour_list()

if self.phase_model:
self._validate_phase_model()

def _validate_filter_list(self):
"""Validation checks for the filter_list command-line argument."""
expected_filters = ["u", "g", "r", "i", "z", "y"]
Expand Down Expand Up @@ -149,3 +154,16 @@ def _validate_sql_filename(self):
raise ValueError(
"The file supplied for the command-line argument --sql_filename cannot be found."
)

def _validate_phase_model(self):
"""Validation checks for the phase_model command-line argument."""
expected_models = ["HG", "HG1G2", "HG12", "HG12_Pen16", "LinearPhaseFunc"]
err_msg_model = (
"Unexpected model in --phase_model command-line arguments. Please select from {}".format(
expected_models
)
)

if self.phase_model not in expected_models:
logging.error(err_msg_model)
raise ValueError(err_msg_model)
7 changes: 5 additions & 2 deletions src/adler/utilities/science_utilities.py
Original file line number Diff line number Diff line change
Expand Up @@ -78,7 +78,8 @@ def zero_func(x, axis=None):
return 0


def sigma_clip(data_res, kwargs={"maxiters": 1, "cenfunc": zero_func}):
# def sigma_clip(data_res, kwargs={"sigma":3, "maxiters": 1, "cenfunc": zero_func}):
def sigma_clip(data_res, **kwargs):
"""Wrapper function for astropy.stats.sigma_clip, here we define the default centre of the data (the data - model residuals) to be zero
Parameters
Expand Down Expand Up @@ -206,7 +207,9 @@ def get_df_obs_filt(planetoid, filt, x_col="midPointMjdTai", x1=None, x2=None, c
if "AbsMag" in col_list:
# calculate the model absolute magnitude
# TODO: add robustness to the units, phaseAngle and reduced_mag must match pc_model
df_obs["AbsMag"] = pc_model.AbsMag(obs.phaseAngle * u.deg, obs.reduced_mag * u.mag).value
# For now we must assume that there are no units, and that degrees have been passed...
# df_obs["AbsMag"] = pc_model.AbsMag(obs.phaseAngle * u.deg, obs.reduced_mag * u.mag).value
df_obs["AbsMag"] = pc_model.AbsMag(np.radians(obs.phaseAngle), obs.reduced_mag)

# select only the required columns
df_obs = df_obs[col_list]
Expand Down
Loading

0 comments on commit fd5bb4a

Please sign in to comment.