Skip to content

Commit

Permalink
Make bflist a polars dataframe
Browse files Browse the repository at this point in the history
  • Loading branch information
lukeshingles committed Feb 13, 2024
1 parent 4243b14 commit abda351
Show file tree
Hide file tree
Showing 3 changed files with 83 additions and 25 deletions.
41 changes: 26 additions & 15 deletions artistools/misc.py
Original file line number Diff line number Diff line change
Expand Up @@ -892,6 +892,8 @@ def add_derived_metadata(metadata: dict[str, t.Any]) -> dict[str, t.Any]:

return add_derived_metadata(metadata)

print(f"No metadata found for: {filepath}")

return {}


Expand Down Expand Up @@ -944,24 +946,33 @@ def merge_pdf_files(pdf_files: list[str]) -> None:
print(f"Files merged and saved to {resultfilename}.pdf")


@lru_cache(maxsize=2)
def get_bflist(modelpath: Path | str) -> dict[int, tuple[int, int, int, int]]:
def get_bflist(modelpath: Path | str) -> pl.DataFrame:
"""Return a dict of bound-free transitions from bflist.out."""
compositiondata = get_composition_data(modelpath)
bflist = {}
bflistpath = firstexisting(["bflist.out", "bflist.dat"], folder=modelpath, tryzipped=True)
with zopen(bflistpath) as filein:
bflistcount = int(filein.readline())

for _ in range(bflistcount):
rowints = [int(x) for x in filein.readline().split()]
i, elementindex, ionindex, level = rowints[:4]
upperionlevel = rowints[4] if len(rowints) > 4 else -1
atomic_number = compositiondata.Z[elementindex]
ionstage = ionindex + compositiondata.lowermost_ionstage[elementindex]
bflist[i] = (atomic_number, ionstage, level, upperionlevel)

return bflist
print(f"Loading {bflistpath}")

dfboundfree = pl.read_csv(
bflistpath,
skip_rows=1,
separator=" ",
new_columns=["i", "elementindex", "ionindex", "lowerlevel", "upperionlevel"],
dtypes={
"i": pl.Int32,
"elementindex": pl.Int32,
"ionindex": pl.Int32,
"lowerlevel": pl.Int32,
"upperionlevel": pl.Int32,
},
)
dfboundfree = dfboundfree.with_columns(
atomic_number=pl.col("elementindex").map_elements(lambda elementindex: compositiondata["Z"][elementindex]),
ionstage=pl.col("ionindex")
+ pl.col("elementindex").map_elements(lambda elementindex: compositiondata["lowermost_ionstage"][elementindex]),
)
return dfboundfree.drop(["elementindex", "ionindex"]).select(
["atomic_number", "ionstage", "lowerlevel", "upperionlevel"]
)


linetuple = namedtuple("linetuple", "lambda_angstroms atomic_number ionstage upperlevelindex lowerlevelindex")
Expand Down
46 changes: 45 additions & 1 deletion artistools/packets/packets.py
Original file line number Diff line number Diff line change
Expand Up @@ -201,7 +201,9 @@ def emtrue_timestep(packet) -> int:


def add_derived_columns_lazy(
dfpackets: pl.LazyFrame | pl.DataFrame, modelmeta: dict[str, t.Any], dfmodel: pd.DataFrame | pl.LazyFrame | None
dfpackets: pl.LazyFrame | pl.DataFrame,
modelmeta: dict[str, t.Any] | None = None,
dfmodel: pd.DataFrame | pl.LazyFrame | None = None,
) -> pl.LazyFrame:
"""Add columns to a packets DataFrame that are derived from the values that are stored in the packets files.
Expand All @@ -228,6 +230,9 @@ def add_derived_columns_lazy(
]
)

if modelmeta is None:
return dfpackets

if modelmeta["dimensions"] > 1:
t_model_s = modelmeta["t_model_init_days"] * 86400.0
vmax = modelmeta["vmax_cmps"]
Expand Down Expand Up @@ -322,6 +327,43 @@ def readfile_text(packetsfile: Path | str, modelpath: Path = Path()) -> pl.DataF
print(f"\nBad Gzip File: {packetsfile}")
raise

dtype_overrides = {
"absorption_freq": pl.Float32,
"absorption_type": pl.Int32,
"absorptiondirx": pl.Float32,
"absorptiondiry": pl.Float32,
"absorptiondirz": pl.Float32,
"e_cmf": pl.Float64,
"e_rf": pl.Float64,
"em_posx": pl.Float32,
"em_posy": pl.Float32,
"em_posz": pl.Float32,
"em_time": pl.Float32,
"emissiontype": pl.Int32,
"escape_time": pl.Float32,
"escape_type_id": pl.Int32,
"interactions": pl.Int32,
"last_event": pl.Int32,
"nscatterings": pl.Int32,
"nu_cmf": pl.Float32,
"nu_rf": pl.Float32,
"number": pl.Int32,
"originated_from_positron": pl.Int32,
"pellet_nucindex": pl.Int32,
"pol_dirx": pl.Float32,
"pol_diry": pl.Float32,
"pol_dirz": pl.Float32,
"scat_count": pl.Int32,
"stokes1": pl.Float32,
"stokes2": pl.Float32,
"stokes3": pl.Float32,
"t_decay": pl.Float32,
"true_emission_velocity": pl.Float32,
"trueem_time": pl.Float32,
"trueemissiontype": pl.Int32,
"type_id": pl.Int32,
}

try:
dfpackets = pl.read_csv(
fpackets,
Expand All @@ -330,6 +372,7 @@ def readfile_text(packetsfile: Path | str, modelpath: Path = Path()) -> pl.DataF
comment_prefix="#",
new_columns=column_names,
infer_schema_length=20000,
dtypes=dtype_overrides,
)

except Exception:
Expand Down Expand Up @@ -418,6 +461,7 @@ def convert_text_to_parquet(
def get_packetsfilepaths(
modelpath: str | Path, maxpacketfiles: int | None = None, printwarningsonly: bool = False
) -> list[Path]:
"""Get a list of Paths to parquet-formatted packets files, (which are generated from text files if needed)."""
nprocs = at.get_nprocs(modelpath)

searchfolders = [Path(modelpath, "packets"), Path(modelpath)]
Expand Down
21 changes: 12 additions & 9 deletions artistools/spectra/spectra.py
Original file line number Diff line number Diff line change
Expand Up @@ -800,7 +800,7 @@ def get_flux_contributions_from_packets(
filterfunc: t.Callable[[np.ndarray], np.ndarray] | None = None,
groupby: t.Literal["ion", "line", "upperterm", "terms"] | None = "ion",
modelgridindex: int | None = None,
use_escapetime: bool = False,
use_time: t.Literal["arrival", "emission", "escape"] = "arrival",
use_lastemissiontype: bool = True,
emissionvelocitycut: float | None = None,
directionbin: int | None = None,
Expand Down Expand Up @@ -848,13 +848,11 @@ def get_emprocesslabel(emtype: int) -> str:
return "free-free"

bfindex = -emtype - 1
if bfindex in bflist:
(atomic_number, ionstage, level) = bflist[bfindex][:3]
if groupby == "line":
return f"{at.get_ionstring(atomic_number, ionstage)} bound-free {level}"
return f"{at.get_ionstring(atomic_number, ionstage)} bound-free"
atomic_number, ionstage = bflist.item(bfindex, "atomic_number"), bflist.item(bfindex, "ionstage")

return f"? bound-free (bfindex={bfindex})"
if groupby == "line":
return f"{at.get_ionstring(atomic_number, ionstage)} bound-free {bflist.item(bfindex, "lowerlevel")}-{bflist.item(bfindex, "upperionlevel")}"
return f"{at.get_ionstring(atomic_number, ionstage)} bound-free"

def get_absprocesslabel(abstype: int) -> str:
if abstype >= 0:
Expand All @@ -875,15 +873,19 @@ def get_absprocesslabel(abstype: int) -> str:
nprocs_read, lzdfpackets = at.packets.get_packets_pl(
modelpath, maxpacketfiles=maxpacketfiles, packet_type="TYPE_ESCAPE", escape_type="TYPE_RPKT"
)
if emissionvelocitycut is not None:
lzdfpackets = at.packets.add_derived_columns_lazy(lzdfpackets)
lzdfpackets = lzdfpackets.filter(pl.col("emission_velocity") > emissionvelocitycut)

lzdfpackets = lzdfpackets.filter(pl.col("t_arrive_d").is_between(float(timelowdays), float(timehighdays)))

cols = {"t_arrive_d", "e_rf"}
cols = {"e_rf"}
cols.add({"arrival": "t_arrive_d", "emission": "em_time", "escape": "excape_time"}[use_time])

if getemission:
cols |= {"emissiontype_str", "nu_rf"}
emtypes = lzdfpackets.select(emtypecolumn).collect().get_column(emtypecolumn).unique().sort()

emtypes = lzdfpackets.select(emtypecolumn).collect().get_column(emtypecolumn).unique().sort()
lzdfpackets = lzdfpackets.join(
pl.DataFrame({emtypecolumn: emtypes, "emissiontype_str": emtypes.map_elements(get_emprocesslabel)}).lazy(),
on=emtypecolumn,
Expand Down Expand Up @@ -924,6 +926,7 @@ def get_absprocesslabel(abstype: int) -> str:
timehighdays=timehighdays,
lambda_min=lambda_min,
lambda_max=lambda_max,
use_time=use_time,
delta_lambda=delta_lambda,
fnufilterfunc=filterfunc,
nprocs_read_dfpackets=(nprocs_read, emissiongroups[groupname]),
Expand Down

0 comments on commit abda351

Please sign in to comment.