Skip to content

Commit

Permalink
abundances stage wip
Browse files Browse the repository at this point in the history
  • Loading branch information
andycasey committed Jan 10, 2025
1 parent a93599b commit 72d1f22
Show file tree
Hide file tree
Showing 6 changed files with 171 additions and 212 deletions.
106 changes: 88 additions & 18 deletions src/astra/models/aspcap.py
Original file line number Diff line number Diff line change
Expand Up @@ -10,36 +10,74 @@
DateTimeField,
BooleanField,
)

from astra.models.pipeline import PipelineOutputModel
from astra.models.ferre import FerreCoarse, FerreStellarParameters, FerreChemicalAbundances
from astra.models.source import Source
from astra.models.spectrum import Spectrum
from astra.glossary import Glossary
from playhouse.hybrid import hybrid_property
from functools import cached_property
from astra.pipelines.ferre.utils import (get_apogee_pixel_mask, parse_ferre_spectrum_name)

APOGEE_FERRE_MASK = get_apogee_pixel_mask()

"""
@cached_property
def ferre_flux(self):
return self._get_pixel_array("params/flux.input")
@cached_property
def ferre_e_flux(self):
return self._get_pixel_array("params/e_flux.input")
#@cached_property
#def model_flux(self):
# return self._get_pixel_array("params/model_flux.output")
@cached_property
def rectified_model_flux(self):
return self._get_pixel_array("params/rectified_model_flux.output")
@cached_property
def rectified_flux(self):
return self._get_pixel_array("params/rectified_flux.output")
@cached_property
def e_rectified_flux(self):
continuum = self.ferre_flux / self.rectified_flux
return self.ferre_e_flux / continuum
"""

class StellarParameterPixelAccessor(BasePixelArrayAccessor):
class ASPCAPPixelArrayAccessor(BasePixelArrayAccessor):

def __get__(self, instance, instance_type=None):
if instance is not None:
try:
return instance.__pixel_data__[self.name]
except (AttributeError, KeyError):
# Load them all.
if not hasattr(instance, "__pixel_data__"):
instance.__pixel_data__ = {}

upstream = FerreStellarParameters.get(instance.stellar_parameters_task_pk)
continuum = upstream.unmask(
(upstream.rectified_model_flux/upstream.model_flux)
/ (upstream.rectified_flux/upstream.ferre_flux)
)

instance.__pixel_data__.setdefault("continuum", continuum)
instance.__pixel_data__.setdefault("model_flux", upstream.unmask(upstream.model_flux))

return instance.__pixel_data__[self.name]

# Stellar parameter case first, since we have to load a bunch of stuff.
if self.name not in instance.__pixel_data__:
if self.name in ("model_flux", "continuum"):
rectified_model_flux = instance._get_output_pixel_array("params", "rectified_model_flux.output")
model_flux = instance._get_output_pixel_array("params", "model_flux.output")
rectified_flux = instance._get_output_pixel_array("params", "rectified_flux.output")
ferre_flux = instance._get_input_pixel_array("params", "flux.input")

continuum = instance._unmask_pixel_array(
(rectified_model_flux/model_flux) / (rectified_flux/ferre_flux)
)
instance.__pixel_data__.setdefault("continuum", continuum)
instance.__pixel_data__.setdefault("model_flux", instance._unmask_pixel_array(model_flux))

else:
# Chemical abundance pixel array.
x_h = self.name[len("model_flux_"):]
#isntance._get_output_pixel_array("abundances", "")
raise NotImplementedError

return instance.__pixel_data__[self.name]
return self.field


Expand Down Expand Up @@ -84,6 +122,7 @@ class ASPCAP(PipelineOutputModel):

""" APOGEE Stellar Parameter and Chemical Abundances Pipeline (ASPCAP) """


#> Spectral Data
wavelength = PixelArray(
accessor_class=LogLambdaArrayAccessor,
Expand All @@ -94,11 +133,11 @@ class ASPCAP(PipelineOutputModel):
),
)
model_flux = PixelArray(
accessor_class=StellarParameterPixelAccessor,
accessor_class=ASPCAPPixelArrayAccessor,
help_text="Model flux at optimized stellar parameters"
)
continuum = PixelArray(
accessor_class=StellarParameterPixelAccessor,
accessor_class=ASPCAPPixelArrayAccessor,
help_text="Continuum"
)

Expand Down Expand Up @@ -624,6 +663,9 @@ def flag_bad(self):
coarse_rchi2 = FloatField(null=True, help_text=Glossary.coarse_rchi2)
coarse_penalized_rchi2 = FloatField(null=True, help_text="Penalized reduced chi-squared for coarse grid")

pwd = TextField(null=True, help_text="Working directory")
ferre_index = IntegerField(null=True, help_text="Index of the FERRE run")

"""
#> Task Primary Keys
stellar_parameters_task_pk = ForeignKeyField(FerreStellarParameters, unique=True, null=True, lazy_load=False, help_text="Task primary key for stellar parameters")
Expand Down Expand Up @@ -730,6 +772,34 @@ def flag_bad(self):
raw_e_v_h = FloatField(null=True, help_text=Glossary.raw_e_v_h)


def _unmask_pixel_array(self, array, fill_value=np.nan):
unmasked_array = fill_value * np.ones(APOGEE_FERRE_MASK.shape)
unmasked_array[APOGEE_FERRE_MASK] = array
return unmasked_array


def _get_pixel_array_kwds(self, stage, name, **kwargs):
kwds = dict(
fname=f"{self.pwd}/{stage}/{self.short_grid_name}/{name}",
skiprows=int(self.ferre_index),
max_rows=1,
)
return kwds

def _get_input_pixel_array(self, stage, name):
return np.loadtxt(**self._get_pixel_array_kwds(stage, name))

def _get_output_pixel_array(self, stage, name, P=7514):
kwds = self._get_pixel_array_kwds(stage, name)
name, = np.atleast_1d(np.loadtxt(usecols=(0, ), dtype=str, **kwds))
array = np.loadtxt(usecols=range(1, 1+P), **kwds)
meta = parse_ferre_spectrum_name(name)
assert int(meta["source_pk"]) == self.source_pk
assert int(meta["spectrum_pk"]) == self.spectrum_pk
assert int(meta["index"]) == self.ferre_index
return array




def apply_noise_model():
Expand Down
25 changes: 12 additions & 13 deletions src/astra/models/ferre.py
Original file line number Diff line number Diff line change
Expand Up @@ -60,18 +60,18 @@ def unmask(self, array, fill_value=np.nan):
def _get_input_pixel_array(self, basename):
return np.loadtxt(
fname=f"{self.pwd}/{basename}",
skiprows=int(self.ferre_input_index),
skiprows=int(self.ferre_index),
max_rows=1,
)


def _get_output_pixel_array(self, basename, P=7514):

#assert self.ferre_input_index >= 0
#assert self.ferre_index >= 0

kwds = dict(
fname=f"{self.pwd}/{basename}",
skiprows=int(self.ferre_output_index),
skiprows=int(self.ferre_index),
max_rows=1,
)
'''
Expand All @@ -83,17 +83,17 @@ def _get_output_pixel_array(self, basename, P=7514):
if (
(int(meta["source_pk"]) != self.source_pk)
or (int(meta["spectrum_pk"]) != self.spectrum_pk)
or (int(meta["index"]) != self.ferre_input_index)
or (int(meta["index"]) != self.ferre_index)
):
raise a
except:
del kwds["skiprows"]
del kwds["max_rows"]
name = get_ferre_spectrum_name(self.ferre_input_index, self.source_pk, self.spectrum_pk, self.initial_flags, self.upstream_id)
name = get_ferre_spectrum_name(self.ferre_index, self.source_pk, self.spectrum_pk, self.initial_flags, self.upstream_id)
index = list(np.loadtxt(usecols=(0, ), dtype=str, **kwds)).index(name)
self.ferre_output_index = index
self.ferre_index = index
self.save()
print("saved!")
kwds["skiprows"] = index
Expand All @@ -108,7 +108,7 @@ def _get_output_pixel_array(self, basename, P=7514):
meta = parse_ferre_spectrum_name(name)
assert int(meta["source_pk"]) == self.source_pk
assert int(meta["spectrum_pk"]) == self.spectrum_pk
assert int(meta["index"]) == self.ferre_input_index
assert int(meta["index"]) == self.ferre_index

return array

Expand Down Expand Up @@ -180,8 +180,7 @@ class FerreCoarse(PipelineOutputModel, FerreOutputMixin):

#> FERRE Access Fields
ferre_name = TextField(default="")
ferre_input_index = IntegerField(default=-1)
ferre_output_index = IntegerField(default=-1)
ferre_index = IntegerField(default=-1)
ferre_n_obj = IntegerField(default=-1)

#> Summary Statistics
Expand Down Expand Up @@ -311,8 +310,8 @@ class FerreStellarParameters(PipelineOutputModel, FerreOutputMixin):
# TODO: flag definitions for each dimension (DRY)
#> FERRE Access Fields
ferre_name = TextField(default="")
ferre_input_index = IntegerField(default=-1)
ferre_output_index = IntegerField(default=-1)
ferre_index = IntegerField(default=-1)
ferre_index = IntegerField(default=-1)
ferre_n_obj = IntegerField(default=-1)

#> Summary Statistics
Expand Down Expand Up @@ -450,8 +449,8 @@ def ferre_e_flux(self):
# TODO: flag definitions for each dimension (DRY)
#> FERRE Access Fields
ferre_name = TextField(default="")
ferre_input_index = IntegerField(default=-1)
ferre_output_index = IntegerField(default=-1)
ferre_index = IntegerField(default=-1)
ferre_index = IntegerField(default=-1)
ferre_n_obj = IntegerField(default=-1)

#> Summary Statistics
Expand Down
Loading

0 comments on commit 72d1f22

Please sign in to comment.