diff --git a/bin/astra b/bin/astra index 80b323a..a3868e1 100755 --- a/bin/astra +++ b/bin/astra @@ -1,861 +1,416 @@ #!/usr/bin/env python3 -import click - -# Common options. -@click.group() -@click.option("-v", "verbose", default=False, is_flag=True, help="verbose mode") -@click.pass_context -def cli(context, verbose): - context.ensure_object(dict) - context.obj["verbose"] = verbose - # Overwrite settings in ~/.astra/astra.yml - # from astra import log - # log.set_level(10 if verbose else 20) - - -@cli.command() -@click.option("--drop-tables", is_flag=True) -@click.option("--delay", default=10) -# default to grant permissions -@click.option("--no-grant-permissions", is_flag=True, default=False) -def initdb(drop_tables, delay, no_grant_permissions): - """Initialize the database.""" - from time import sleep - from astra.utils import log - from astra.models import ( - base, - apogee, - apogeenet, - aspcap, - #boss, - #classifier, - ferre, - #lineforest, - #madgics, - #mdwarftype, - #slam, - #snow_white, - source, - spectrum, - #the_payne - ) +import typer +from typing import Optional +from typing_extensions import Annotated + +app = typer.Typer() + +@app.command() +def srun( + task: Annotated[str, typer.Argument(help="The task name to run (e.g., `aspcap`, or `astra.pipelines.aspcap.aspcap`).")], + model: Annotated[str, typer.Argument( + help=( + "The input model to use (e.g., `ApogeeCombinedSpectrum`, `BossCombinedSpectrum`). " + ) + )] = None, + nodes: Annotated[int, typer.Option(help="The number of nodes to use.", min=1)] = 1, + procs: Annotated[int, typer.Option(help="The number of processes to use per node.", min=1)] = 1, + limit: Annotated[int, typer.Option(help="Limit the number of inputs.", min=1)] = None, + account: Annotated[str, typer.Option(help="Slurm account")] = "sdss-np", + partition: Annotated[str, typer.Option(help="Slurm partition")] = None, + gres: Annotated[str, typer.Option(help="Slurm generic resources")] = None, + mem: Annotated[str, typer.Option(help="Memory per node")] = None, + time: Annotated[str, typer.Option(help="Wall-time")] = "24:00:00", +): + """Distribute an Astra task over many nodes using Slurm.""" + + partition = partition or account + + import os + import sys + import numpy as np + import concurrent.futures + import subprocess + from datetime import datetime + from tempfile import TemporaryDirectory + from peewee import JOIN + from importlib import import_module + from astra import models, __version__, generate_queries_for_task + from astra.utils import silenced, expand_path + from rich.progress import Progress, SpinnerColumn, TextColumn, TaskProgressColumn, TimeRemainingColumn, BarColumn, MofNCompleteColumn + + # TODO: no hard coding of paths + ASTRA = "/uufs/chpc.utah.edu/common/home/sdss50/sdsswork/mwm/spectro/astra/astra/astra_dev/bin/new_astra" + + _, q = next(generate_queries_for_task(task, model, limit)) + + total = q.count() + workers = nodes * procs + limit = int(np.ceil(total / workers)) + today = datetime.now().strftime("%Y-%m-%d") + + with Progress( + SpinnerColumn(), + TextColumn("[progress.description]{task.description}"), + BarColumn(), + transient=True + ) as p: + + executor = concurrent.futures.ProcessPoolExecutor(nodes) + + # Load a whole bunch of sruns in processes + futures = {} + with TemporaryDirectory(dir=expand_path("$PBS"), prefix=f"{task}-{today}-", delete=False) as td: + p.print(f"Working directory: {td}") + for n in range(nodes): + job_name = f"{task}" + (f"-{n}" if nodes > 1 else "") + # TODO: Let's not hard code this here. 
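+ # Build a small per-node shell script: `procs` paged `astra run` invocations are
+ # started in the background and then awaited with `wait`.
+ # Assumption: exporting CLUSTER=1 marks the child processes as already running inside
+ # a Slurm allocation (the old `new-execute` command checked os.environ["CLUSTER"] for this).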
+ commands = ["export CLUSTER=1", "echo hello"] + for page in range(n * procs, (n + 1) * procs): + commands.append(f"{ASTRA} run {task} {model} --limit {limit} --page {page + 1} &") + commands.append("wait") + + script_path = f"{td}/node_{n}.sh" + with open(script_path, "w") as fp: + fp.write("\n".join(commands)) + + os.system(f"chmod +x {script_path}") + executable = [ + "srun", + "--nodes=1", + f"--partition={partition}", + f"--account={account}", + f"--job-name={job_name}", + f"--time={time}", + f"--output={td}/{n}.out", + f"--error={td}/{n}.err", + ] + if mem is not None: + executable.append(f"--mem={mem}") + if gres is not None: + executable.append(f"--gres={gres}") + + executable.extend(["bash", "-c", f"{script_path}"]) + + t = p.add_task(description=f"Running {job_name}", total=None) + job = executor.submit( + subprocess.run, + executable, + capture_output=True + ) + futures[job] = (n, t) + + max_returncode = 0 + for future in concurrent.futures.as_completed(futures.keys()): + n, t = futures[future] + result = future.result() + if result.returncode == 0: + p.update(t, description=f"Completed") + p.remove_task(t) + else: + p.update(t, description=f"Error code {result.returncode} returned from {task}-{n}") + p.print(result.stderr.decode("utf-8")) + + max_returncode = max(max_returncode, result.returncode) + + sys.exit(max_returncode) + + +@app.command() +def run( + task: Annotated[str, typer.Argument(help="The task name to run (e.g., `aspcap`, or `astra.pipelines.aspcap.aspcap`).")], + spectrum_model: Annotated[str, typer.Argument( + help=( + "The spectrum model to use (e.g., `ApogeeCombinedSpectrum`, `BossCombinedSpectrum`). " + "If `None` is given then all spectrum models accepted by the task will be analyzed." + ) + )] = None, + limit: Annotated[int, typer.Option(help="Limit the number of spectra.", min=1)] = None, + page: Annotated[int, typer.Option(help="Page to start results from (`limit` spectra per `page`).", min=1)] = None, +): + """Run an Astra task on spectra.""" + from rich.progress import Progress, SpinnerColumn, TextColumn, TaskProgressColumn, TimeRemainingColumn, BarColumn, MofNCompleteColumn - models = base.BaseModel.__subclasses__() - with base.database.atomic(): - if drop_tables: - log.warning(f"Dropping database tables in {delay} seconds..") - sleep(delay) - base.database.drop_tables(models, cascade=True) + from astra import models, __version__, generate_queries_for_task + from astra.utils import resolve_task + + fun = resolve_task(task) + + with Progress(SpinnerColumn(), TextColumn("[progress.description]{task.description}")) as p: + t = p.add_task(description="Resolving task", total=None) + fun = resolve_task(task) + p.remove_task(t) + + messages = [] + with Progress( + SpinnerColumn(), + TextColumn("[progress.description]{task.description}"), + BarColumn(), + MofNCompleteColumn(), + TimeRemainingColumn(), + transient=False + ) as p: - base.database.create_tables(models) - log.info(f"Created {len(models)} database tables: {models}") - - if not no_grant_permissions: - schema = base.BaseModel._meta.schema - log.info(f"Granting permissions on schema {schema} to role 'sdss'") - base.database.execute_sql(f"GRANT ALL PRIVILEGES ON SCHEMA {schema} TO GROUP sdss;") - base.database.execute_sql(f"GRANT ALL PRIVILEGES ON ALL TABLES IN SCHEMA {schema} to sdss;") - - return None - - -@cli.command(context_settings=dict(ignore_unknown_options=True)) -@click.option("--apred", default=None) -@click.option("--run2d", default=None) -@click.option("--limit", default=None) 
-@click.option("--include-dr17", is_flag=True, default=False) -def migrate(apred, run2d, limit, include_dr17): - """Migrate data from the SDSS5 database.""" - - from astra.utils import log - from astra.models.source import Source - - from astra.migrations.apogee import ( - migrate_sdss4_dr17_apogee_spectra_from_sdss5_catalogdb, - migrate_apvisit_metadata_from_image_headers, - fix_version_id_edge_cases - ) + for model, q in generate_queries_for_task(fun, spectrum_model, limit, page=page): + t = p.add_task(description=f"Running {fun.__name__} on {model.__name__}", total=limit) + total = q.count() + p.update(t, total=total) + if total > 0: + for n, r in enumerate(fun(q), start=1): + p.update(t, advance=1, refresh=True) + messages.append(f"Processed {n} {model.__name__} spectra with {fun.__name__}") + p.update(t, completed=True) + + list(map(typer.echo, messages)) + + +@app.command() +def migrate( + apred: Optional[str] = typer.Option(None, help="APOGEE data reduction pipeline version."), + run2d: Optional[str] = typer.Option(None, help="BOSS data reduction pipeline version."), +): + """Migrate spectra and auxillary information to the Astra database.""" + + import os + import multiprocessing as mp + from signal import SIGKILL + from rich.progress import Text, Progress, SpinnerColumn, Text, TextColumn, TaskProgressColumn, TimeRemainingColumn, TimeElapsedColumn, BarColumn, MofNCompleteColumn as _MofNCompleteColumn + + class MofNCompleteColumn(_MofNCompleteColumn): + def render(self, task): + completed = int(task.completed) + total = f"{int(task.total):,}" if task.total is not None else "?" + total_width = len(str(total)) + return Text( + f"{completed:{total_width},d}{self.separator}{total}", + style="progress.download", + ) + from astra.migrations.boss import ( - migrate_spectra_from_spall_file, + migrate_from_spall_file, migrate_specfull_metadata_from_image_headers ) + #from astra.migrations.apogee import ( + # migrate_apvisit_metadata_from_image_headers, + #) + from astra.migrations.new_apogee import ( + migrate_apogee_spectra_from_sdss5_apogee_drpdb, + migrate_dithered_metadata + ) from astra.migrations.catalog import ( - migrate_gaia_source_ids, migrate_healpix, - migrate_tic_v8_identifier, migrate_twomass_photometry, migrate_unwise_photometry, migrate_glimpse_photometry, + migrate_tic_v8_identifier, + migrate_gaia_source_ids, migrate_gaia_dr3_astrometry_and_photometry, migrate_zhang_stellar_parameters, - migrate_bailer_jones_distances + migrate_bailer_jones_distances, + migrate_gaia_synthetic_photometry ) from astra.migrations.misc import ( - compute_f_night_time_for_boss_visits, + compute_f_night_time_for_boss_visits, compute_f_night_time_for_apogee_visits, - set_missing_gaia_source_ids_to_null, + update_visit_spectra_counts, compute_n_neighborhood, - update_visit_spectra_counts + update_galactic_coordinates, + compute_w1mag_and_w2mag, + fix_unsigned_apogee_flags ) from astra.migrations.reddening import update_reddening - from astra.migrations.targeting import migrate_carton_assignments_to_bigbitfield - - log.info("Starting ingestion. 
This will take a long time.") - - if include_dr17: - log.info(f"Ingesting DR17 APOGEE spectra") - migrate_sdss4_dr17_apogee_spectra_from_sdss5_catalogdb(limit=limit) - - if run2d is not None: - log.info(f"Ingesting SDSS5 BOSS spectra with run2d={run2d}") - migrate_spectra_from_spall_file(run2d, limit=limit) - - log.info(f"Migrating BOSS metadata from headers") - migrate_specfull_metadata_from_image_headers() - - if apred is not None: - from astra.migrations.apogee import ( - migrate_1p3_apvisit_from_sdss5_apogee_drpdb as migrate_apvisit_from_sdss5_apogee_drpdb, - migrate_apstar_from_sdss5_database, - ) - - log.info(f"Ingesting SDSS5 APOGEE spectra with apred={apred}") - migrate_apvisit_from_sdss5_apogee_drpdb(apred, limit=limit) - - fix_version_id_edge_cases() - migrate_apstar_from_sdss5_database(apred, limit=limit) - - log.info(f"Migrating carton assignments") - migrate_carton_assignments_to_bigbitfield() - - log.info(f"Migrating HEALPix") - migrate_healpix() - - log.info(f"Migrating Gaia source identifiers") - migrate_gaia_source_ids() - - log.info(f"Migrating Gaia astrometry and photometry") - migrate_gaia_dr3_astrometry_and_photometry() - - log.info(f"Migrating Zhang et al stellar parameters") - migrate_zhang_stellar_parameters() - - log.info(f"Migrating 2MASS photometry") - migrate_twomass_photometry() - - log.info(f"Migrating unWISE photometry") - migrate_unwise_photometry() - - log.info(f"Migrating GLIMPSE photometry") - migrate_glimpse_photometry() - - log.info(f"Migrating TIC v8 identifiers") - migrate_tic_v8_identifier() - - log.info(f"Migrating Bailer-Jones distances") - migrate_bailer_jones_distances() - - log.info(f"Migration from SDSS5 catalogdb complete") - set_missing_gaia_source_ids_to_null() - - log.info(f"Computing f_night_fraction for BOSS visits") - compute_f_night_time_for_boss_visits() - - log.info(f"Computing f_night_fraction for APOGEE visits") - compute_f_night_time_for_apogee_visits() - - log.info(f"Computing neighbourhood size") - compute_n_neighborhood() - - log.info("Computing visit spectra counts") - update_visit_spectra_counts() - - log.info(f"Computing extinction") - update_reddening() - - log.info(f"Migrating apVisit metadata from image headers") - migrate_apvisit_metadata_from_image_headers() - log.info("Done") - - -@cli.command(context_settings=dict(ignore_unknown_options=True)) -@click.option("--slurm-profile", default=None, help="Use Slurm profile specified in Astra config file. If None is given, it will default to the profile for the task name, or `default`.") -@click.option("--slurm-dir", default=None) -@click.option("--limit", default=None, type=int) -@click.option("--nodes", default=1, type=int) -def create_mwm_products(slurm_profile, slurm_dir, limit, nodes): - - import os - from astra import config, log - from astra.utils import expand_path - resolved_task = "astra.products.mwm.create_all_mwm_products" - - slurm_profile_config = config.get("slurm", dict(profiles={})).get("profiles", {}) - if slurm_profile is not None: - if slurm_profile not in slurm_profile_config: - raise click.BadArgumentUsage(f"Cannot find Slurm profile '{slurm_profile}' in Astra config.") - else: - try_slurm_profile_names = (resolved_task, resolved_task.split(".")[-1], "default") - for slurm_profile in try_slurm_profile_names: - if slurm_profile in slurm_profile_config: - log.info(f"Using Slurm profile '{slurm_profile}'") - break - else: - raise click.BadOptionUsage(f"Cannot find any Slurm profile in Astra config. Use `--slurm-profile PROFILE` to specify. 
Tried: {', '.join(slurm_profile_config)}") - - if slurm_dir is None: - from datetime import datetime - from tempfile import mkdtemp - prefix = f"{datetime.now().strftime('%Y-%m-%d')}-{resolved_task.split('.')[-1][:30]}-" - slurm_dir = mkdtemp(prefix=prefix, dir=expand_path(f"$PBS/")) - os.chmod(slurm_dir, 0o755) - log.info(f"Using Slurm directory: {slurm_dir}") - job_name = f"{os.path.basename(slurm_dir)}" - else: - os.makedirs(slurm_dir, exist_ok=True) - job_name = f"{resolved_task.split('.')[-1]}" - - if limit is None: - from astra.models import Source - from astra.products.mwm_summary import DEFAULT_MWM_WHERE - from astra.models.apogee import ApogeeVisitSpectrum - from astra.models.boss import BossVisitSpectrum - from peewee import JOIN - - apreds = ("1.3", "dr17") - run2ds = ("v6_1_3", ) - - q_apogee = ( - ApogeeVisitSpectrum - .select(ApogeeVisitSpectrum.source_pk) - .distinct(ApogeeVisitSpectrum.source_pk) - .where(ApogeeVisitSpectrum.apred.in_(apreds)) - .alias("q_apogee") - ) - q_boss = ( - BossVisitSpectrum - .select(BossVisitSpectrum.source_pk) - .distinct(BossVisitSpectrum.source_pk) - .where(BossVisitSpectrum.run2d.in_(run2ds)) - .alias("q_boss") - ) - - q = ( - Source - .select() - .distinct(Source.pk) - .where( - Source.sdss_id.is_null(False) - & DEFAULT_MWM_WHERE - ) - .join(q_apogee, JOIN.LEFT_OUTER, on=(q_apogee.c.source_pk == Source.pk)) - .switch(Source) - .join(q_boss, JOIN.LEFT_OUTER, on=(q_boss.c.source_pk == Source.pk)) - .where( - (~q_apogee.c.source_pk.is_null()) - | (~q_boss.c.source_pk.is_null()) - ) - ) - - limit = q.count() - - import sys - if slurm_dir is None: - from datetime import datetime - from tempfile import mkdtemp - prefix = f"{datetime.now().strftime('%Y-%m-%d')}-{resolved_task.split('.')[-1][:30]}-" - if page: - prefix += f"{page}-" - slurm_dir = mkdtemp(prefix=prefix, dir=expand_path(f"$PBS/")) - os.chmod(slurm_dir, 0o755) - log.info(f"Using Slurm directory: {slurm_dir}") - job_name = f"{os.path.basename(slurm_dir)}" - else: - os.makedirs(slurm_dir, exist_ok=True) - job_name = f"{resolved_task.split('.')[-1]}" - - slurm_kwds = slurm_profile_config[slurm_profile] - - from astra.utils.slurm import SlurmTask, SlurmJob - - python_threads = slurm_kwds.pop("python_threads", 8) - if slurm_kwds is None: - pre_execute_commands = [] - else: - pre_execute_commands = [ - f"export OMP_NUM_THREADS={python_threads}", - f"export OPENBLAS_NUM_THREADS={python_threads}", - f"export MKL_NUM_THREADS={python_threads}", - f"export VECLIB_MAXIMUM_THREADS={python_threads}", - f"export NUMEXPR_NUM_THREADS={python_threads}" - ] - - n_proc = 32 - limit_per_proc = limit // (n_proc * nodes) + 1 - page_count = 0 - for node in range(nodes): - commands = [] - for page in range(1, n_proc + 1): - commands.append(f"astra execute astra.products.mwm.create_all_mwm_products --page {page + page_count} --limit {limit_per_proc} &") - page_count += page - commands.append("wait") - - this_slurm_dir = os.path.join(slurm_dir, f"{node}") - os.makedirs(this_slurm_dir, exist_ok=True) - os.chmod(this_slurm_dir, 0o755) - - slurm_job = SlurmJob( - [ - SlurmTask(pre_execute_commands + commands) - ], - f"{job_name}_{node}", - dir=this_slurm_dir, - **slurm_kwds, - ) - slurm_job_pk = slurm_job.submit() - - click.echo(f"{slurm_job_pk}") - sys.exit(0) - - -@cli.command(context_settings=dict(ignore_unknown_options=True)) -@click.option("--slurm-profile", default=None, help="Use Slurm profile specified in Astra config file. 
If None is given, it will default to the profile for the task name, or `default`.") -@click.option("--nodes", default=1) -@click.option("--procs-per-node", default=1) -@click.option("--limit", default=None) -@click.option("--page", default=None) -@click.argument("task", nargs=1) -@click.argument("spectrum_type", nargs=1) -@click.argument("kwargs_str", nargs=-1) -def new_execute(slurm_profile, nodes, procs_per_node, limit, page, task, spectrum_type, kwargs_str): - - import os - import numpy as np - import sys - from astra import config, __version__ - from astra.utils import expand_path, log, callable - - import pickle - from inspect import getfullargspec - from tqdm import tqdm - from peewee import chunked, JOIN - - from astra import models - from astra.models.source import Source - from astra.models.spectrum import Spectrum, SpectrumMixin - - - # Do some cleverness about the task name. - for prefix in ("", "astra.", "astra.pipelines.", f"astra.pipelines.{task}."): - try: - resolved_task = f"{prefix}{task}" - f = callable(resolved_task) - except: - None - else: - if prefix: - log.info(f"Resolved '{task}' -> '{resolved_task}'") - break - else: - # Raise exception on the no-prefix case. - f = callable(task) - - kwargs = {} - for each in kwargs_str: - k, v = each.split("=") - kwargs[k.lstrip("-")] = v - - - spectrum_model = getattr(models, spectrum_type) - try: - output_model = getfullargspec(f).annotations["return"].__args__[0] - except: - raise ValueError(f"Cannot infer output model for task {f}, is it missing a type annotation?") - - # Query for spectra that does not have a result in this output model - iterable = ( - spectrum_model - .select( - spectrum_model, - Source - ) - .join( - output_model, - JOIN.LEFT_OUTER, - on=( - (spectrum_model.spectrum_pk == output_model.spectrum_pk) - & (output_model.v_astra == __version__) - ) - ) - .switch(spectrum_model) - .join(Source, attr="source") # convenience to pre-fetch .source attribute on everything - .where(output_model.spectrum_pk.is_null()) + from astra.migrations.targeting import ( + migrate_carton_assignments_to_bigbitfield, + migrate_targeting_cartons ) - for k, v in kwargs.items(): - iterable = iterable.where(getattr(spectrum_model, k) == v) - print(f"requiring {k}={v}") - - if os.environ.get("CLUSTER", False): - q = ( - iterable - .paginate(int(page), int(limit)) - ) - # In Slurm environment, so just execute - for item in tqdm(f(q), total=int(limit)): - None - - else: - # Submit to Slurm - total = limit or iterable.count() - - print(f"Found {total}") - from astra.utils.slurm import SlurmTask, SlurmJob - - # Resolve slurm profile. - slurm_profile_config = config.get("slurm", dict(profiles={})).get("profiles", {}) - if slurm_profile is not None: - if slurm_profile not in slurm_profile_config: - raise click.BadArgumentUsage(f"Cannot find Slurm profile '{slurm_profile}' in Astra config.") - else: - try_slurm_profile_names = (resolved_task, resolved_task.split(".")[-1], "default") - for slurm_profile in try_slurm_profile_names: - if slurm_profile in slurm_profile_config: - log.info(f"Using Slurm profile '{slurm_profile}'") - break - else: - raise click.BadOptionUsage(f"Cannot find any Slurm profile in Astra config. Use `--slurm-profile PROFILE` to specify. 
Tried: {', '.join(slurm_profile_config)}") - - slurm_kwds = slurm_profile_config[slurm_profile] - - n_procs = procs_per_node * nodes - batch_size = int(np.ceil(total / (n_procs))) - - from datetime import datetime - from tempfile import mkdtemp - - python_threads = slurm_kwds.pop("python_threads", 8) - if slurm_kwds is None: - pre_execute_commands = [] - else: - pre_execute_commands = [ - f"export OMP_NUM_THREADS={python_threads}", - f"export OPENBLAS_NUM_THREADS={python_threads}", - f"export MKL_NUM_THREADS={python_threads}", - f"export VECLIB_MAXIMUM_THREADS={python_threads}", - f"export NUMEXPR_NUM_THREADS={python_threads}" - ] - - start_page = 1 - - for node in range(nodes): - - commands = [] - for proc in range(procs_per_node): - - page = start_page + node * procs_per_node + proc - - command = f"astra new-execute --page {page} --limit {batch_size} {resolved_task} {spectrum_type} {' '.join(kwargs_str)}" - commands.append(command) - - prefix = f"{datetime.now().strftime('%Y-%m-%d')}-{resolved_task.split('.')[-1][:30]}-" - if page: - prefix += f"{node}-" - slurm_dir = mkdtemp(prefix=prefix, dir=expand_path(f"$PBS/")) - os.chmod(slurm_dir, 0o755) - log.info(f"Using Slurm directory: {slurm_dir}") - job_name = f"{os.path.basename(slurm_dir)}" - - slurm_job = SlurmJob( - [ - SlurmTask(pre_execute_commands + [c]) for c in commands - ], - job_name, - dir=slurm_dir, - **slurm_kwds, - ) - slurm_job_pk = slurm_job.submit() - print(slurm_job_pk) - - - - -@cli.command(context_settings=dict(ignore_unknown_options=True)) -@click.option("--slurm", is_flag=True, default=False, help="Execute through Slurm (see --slurm-profile)") -@click.option("--slurm-profile", default=None, help="Use Slurm profile specified in Astra config file. If None is given, it will default to the profile for the task name, or `default`.") -@click.option("--slurm-dir", default=None) -@click.option("--page", default=None, type=int) -@click.option("--limit", default=None, type=int) -@click.option("--kwargs-path") -@click.argument("task") -@click.argument("spectra", nargs=-1) -def execute(slurm, slurm_profile, slurm_dir, page, limit, kwargs_path, task, spectra): - """ - Execute a task on one or many spectra. - """ - # Resolve spectrum ids. - #if len(spectra) == 0: - # raise click.UsageError("No spectral model or spectrum identifiers given.") - - import os - import sys - from astra import config, __version__ - from astra.utils import expand_path, log, callable - - import pickle - from inspect import getfullargspec - from tqdm import tqdm - from peewee import chunked, JOIN - - from astra import models - from astra.models.source import Source - from astra.models.spectrum import Spectrum, SpectrumMixin - - kwargs = {} - if kwargs_path is not None: - with open(kwargs_path, "rb") as fp: - kwargs = pickle.load(fp) - - # Parse any additional keyword arguments which unfortunately get lumped into `spectra`. - # TODO: THere must be a nicer way to parse this using click. - _spectra = [] - for arg in spectra: - if arg.startswith("--"): - k, *v = arg[2:].split("=") - k = k.replace("-", "_") - kwargs[k] = "=".join(v).strip('"') - else: - _spectra.append(arg) - - # Do some cleverness about the task name. - for prefix in ("", "astra.", "astra.pipelines.", f"astra.pipelines.{task}."): - try: - resolved_task = f"{prefix}{task}" - f = callable(resolved_task) - except: - None - else: - if prefix: - log.info(f"Resolved '{task}' -> '{resolved_task}'") - break - else: - # Raise exception on the no-prefix case. 
- f = callable(task) - - # TODO: This is all a bit of spaghetti code. Refactor - - if slurm: - # Check that there is any work to do before submitting a job. - - try: - spectrum_pks = list(map(int, _spectra)) - except ValueError: - if len(_spectra) > 1: - raise NotImplementedError("Only one spectrum model allowed for now. This can be changed.") - - # If the first item has a default, then don't do anything special. - model_name, = _spectra - spectrum_model = getattr(models, model_name) - try: - output_model = getfullargspec(f).annotations["return"].__args__[0] - except: - raise a - raise ValueError(f"Cannot infer output model for task {f}, is it missing a type annotation?") - - # Query for spectra that does not have a result in this output model - iterable = ( - spectrum_model - .select( - spectrum_model, - Source - ) - .join( - output_model, - JOIN.LEFT_OUTER, - on=(spectrum_model.spectrum_pk == output_model.spectrum_pk) - ) - .switch(spectrum_model) - .join(Source, attr="source") # convenience to pre-fetch .source attribute on everything - .where(output_model.spectrum_pk.is_null()) - .limit(limit) - ) - total = limit or iterable.count() - log.info(f"Found at least {total} {model_name} spectra that do not have results in {output_model}") + from astra.utils import silenced - else: - total = len(spectrum_pks) - - argspec = getfullargspec(f) - - if len(argspec.defaults) != len(argspec.args) and total == 0: - # Nothing to do. - log.info(f"No spectra to process.") - sys.exit(0) - + ptq = [] + try: + with Progress( + SpinnerColumn(), + TextColumn("[progress.description]{task.description}"), + BarColumn(), + MofNCompleteColumn(), + TimeElapsedColumn(), + TimeRemainingColumn(), + transient=True + ) as progress: + + def process_task(target, *args, description=None, **kwargs): + queue = mp.Queue() + + kwds = dict(queue=queue) + kwds.update(kwargs) + process = mp.Process(target=target, args=args, kwargs=kwds) + process.start() + + task = progress.add_task(description=(description or ""), total=None) + return (process, task, queue) - from astra.utils.slurm import SlurmTask, SlurmJob - - # Resolve slurm profile. - slurm_profile_config = config.get("slurm", dict(profiles={})).get("profiles", {}) - if slurm_profile is not None: - if slurm_profile not in slurm_profile_config: - raise click.BadArgumentUsage(f"Cannot find Slurm profile '{slurm_profile}' in Astra config.") - else: - try_slurm_profile_names = (resolved_task, resolved_task.split(".")[-1], "default") - for slurm_profile in try_slurm_profile_names: - if slurm_profile in slurm_profile_config: - log.info(f"Using Slurm profile '{slurm_profile}'") - break - else: - raise click.BadOptionUsage(f"Cannot find any Slurm profile in Astra config. Use `--slurm-profile PROFILE` to specify. Tried: {', '.join(slurm_profile_config)}") - slurm_kwds = slurm_profile_config[slurm_profile] - - # Submit this job. #TODO: Is there a way for Click to reconstruct the command for us? 
- command = "astra execute " - if page: - command += f"--page {page} " - if limit: - command += f"--limit {limit} " - if kwargs_path: - command += f"--kwargs-path {kwargs_path} " - command += f"{resolved_task} " - command += " ".join(spectra) - - if slurm_dir is None: - from datetime import datetime - from tempfile import mkdtemp - prefix = f"{datetime.now().strftime('%Y-%m-%d')}-{resolved_task.split('.')[-1][:30]}-" - if page: - prefix += f"{page}-" - slurm_dir = mkdtemp(prefix=prefix, dir=expand_path(f"$PBS/")) - os.chmod(slurm_dir, 0o755) - log.info(f"Using Slurm directory: {slurm_dir}") - job_name = f"{os.path.basename(slurm_dir)}" - else: - os.makedirs(slurm_dir, exist_ok=True) - job_name = f"{resolved_task.split('.')[-1]}" - - python_threads = slurm_kwds.pop("python_threads", 8) - if slurm_kwds is None: - pre_execute_commands = [] - else: - pre_execute_commands = [ - f"export OMP_NUM_THREADS={python_threads}", - f"export OPENBLAS_NUM_THREADS={python_threads}", - f"export MKL_NUM_THREADS={python_threads}", - f"export VECLIB_MAXIMUM_THREADS={python_threads}", - f"export NUMEXPR_NUM_THREADS={python_threads}" + if apred is not None or run2d is not None: + if apred is not None: + if apred == "dr17": + from astra.migrations.apogee import migrate_sdss4_dr17_apogee_spectra_from_sdss5_catalogdb + ptq.append(process_task(migrate_sdss4_dr17_apogee_spectra_from_sdss5_catalogdb, description="Ingesting APOGEE dr17 spectra")) + else: + ptq.append(process_task(migrate_apogee_spectra_from_sdss5_apogee_drpdb, apred, description=f"Ingesting APOGEE {apred} spectra")) + if run2d is not None: + ptq.append(process_task(migrate_from_spall_file, run2d, description=f"Ingesting BOSS {run2d} spectra")) + + awaiting = set(t for p, t, q in ptq) + while awaiting: + for p, t, q in ptq: + try: + r = q.get(False) + if r is Ellipsis: + progress.update(t, completed=True) + awaiting.remove(t) + p.join() + progress.update(t, visible=False) + else: + progress.update(t, **r) + if "completed" in r and r.get("completed", None) == 0: + # reset the task + progress.reset(t) + except mp.queues.Empty: + pass + + # Now that we have sources and spectra, we can do other things. 
+ ptq = [ + process_task(migrate_gaia_source_ids, description="Ingesting Gaia DR3 source IDs"), + process_task(migrate_twomass_photometry, description="Ingesting 2MASS photometry"), + process_task(migrate_unwise_photometry, description="Ingesting unWISE photometry"), + process_task(migrate_glimpse_photometry, description="Ingesting GLIMPSE photometry"), + #process_task(migrate_specfull_metadata_from_image_headers, description="Ingesting specFull metadata"), + + + process_task(migrate_dithered_metadata, description="Ingesting APOGEE dithered metadata"), + #process_task(migrate_apvisit_metadata_from_image_headers, description="Ingesting apVisit metadata"), + process_task(migrate_healpix, description="Ingesting HEALPix values"), + process_task(migrate_tic_v8_identifier, description="Ingesting TIC v8 identifiers"), + process_task(update_galactic_coordinates, description="Computing Galactic coordinates"), + process_task(fix_unsigned_apogee_flags, description="Fix unsigned APOGEE flags"), + process_task(migrate_targeting_cartons, description="Ingesting targeting cartons"), + process_task(compute_f_night_time_for_apogee_visits, description="Computing f_night for APOGEE visits"), + process_task(update_visit_spectra_counts, description="Updating visit spectra counts"), ] - slurm_job = SlurmJob( - [ - SlurmTask(pre_execute_commands + [command]) - ], - job_name, - dir=slurm_dir, - **slurm_kwds, - ) - slurm_job_pk = slurm_job.submit() + # reddening needs unwise, 2mass, glimpse, + task_gaia, task_twomass, task_unwise, task_glimpse, task_specfull, *_ = [t for p, t, q in ptq] + reddening_requires = {task_twomass, task_unwise, task_glimpse, task_gaia} + started_reddening = False + awaiting = set(t for p, t, q in ptq) + while awaiting: + additional_tasks = [] + for p, t, q in ptq: + try: + r = q.get(False) + if r is Ellipsis: + progress.update(t, completed=True) + awaiting.remove(t) + p.join() + progress.update(t, visible=False) + if t == task_gaia: + # Add a bunch more tasks! + new_tasks = [ + process_task(migrate_gaia_dr3_astrometry_and_photometry, description="Ingesting Gaia DR3 astrometry and photometry"), + process_task(migrate_zhang_stellar_parameters, description="Ingesting Zhang stellar parameters"), + process_task(migrate_bailer_jones_distances, description="Ingesting Bailer-Jones distances"), + # commented out only because we are getting deadlock errors caused by some other process. 
+ #process_task(migrate_gaia_synthetic_photometry, description="Ingesting Gaia synthetic photometry"), + process_task(compute_n_neighborhood, description="Computing n_neighborhood"), + ] + reddening_requires.update({t for p, t, q in new_tasks[:3]}) # reddening needs Gaia astrometry, Zhang parameters, and Bailer-Jones distances + additional_tasks.extend(new_tasks) + if t == task_specfull: + additional_tasks.append( + process_task(compute_f_night_time_for_boss_visits, description="Computing f_night for BOSS visits") + ) + if t == task_unwise: + additional_tasks.append( + process_task(compute_w1mag_and_w2mag, description="Computing W1, W2 mags") + ) + if not started_reddening and not (awaiting & reddening_requires): + started_reddening = True + #additional_tasks.append( + # process_task(update_reddening, description="Computing extinction") + #) + else: + progress.update(t, **r) + if "completed" in r and r.get("completed", None) == 0: + # reset the task + progress.reset(t) + + except mp.queues.Empty: + pass + + ptq.extend(additional_tasks) + awaiting |= set(t for p, t, q in additional_tasks) + + except KeyboardInterrupt: + """ + with silenced(): + import psutil + parent = psutil.Process(os.getpid()) + for child in parent.children(recursive=True): + child.kill() + """ + raise KeyboardInterrupt + + + +@app.command() +def init( + drop_tables: Optional[bool] = typer.Option(False, help="Drop tables if they exist."), + delay: Optional[int] = typer.Option(10, help="Delay in seconds to wait.") +): + """Initialize the Astra database.""" - click.echo(f"{slurm_job_pk}") - sys.exit(0) - - try: - spectrum_pks = list(map(int, _spectra)) - except ValueError: - if len(_spectra) > 1: - raise NotImplementedError("Only one spectrum model allowed for now. This can be changed.") - - model_name, = _spectra - spectrum_model = getattr(models, model_name) - argspec = getfullargspec(f) - - try: - output_model = argspec.annotations["return"].__args__[0] - except: - raise ValueError(f"Cannot infer output model for task {f}, is it missing a type annotation?") - - # Query for spectra that does not have a result in this output model, for this astra version - where = output_model.spectrum_pk.is_null() - if hasattr(spectrum_model, "v_astra"): - print(f"Restricting {spectrum_model} to have v_astra={__version__}") - where = where & (spectrum_model.v_astra == __version__) - - iterable = ( - spectrum_model - .select( - spectrum_model, - Source - ) - .join( - output_model, - JOIN.LEFT_OUTER, - on=( - (spectrum_model.spectrum_pk == output_model.spectrum_pk) - & (output_model.v_astra == __version__) - ) - ) - .switch(spectrum_model) - .join(Source, attr="source") # convenience to pre-fetch .source attribute on everything - .where(where) + from time import sleep + from rich.progress import Progress, SpinnerColumn, TextColumn, TaskProgressColumn, TimeRemainingColumn + from importlib import import_module + from astra.models.base import (database, BaseModel) + from astra.models.pipeline import PipelineOutputModel + + with Progress( + SpinnerColumn(), + TextColumn("[progress.description]{task.description}"), + TimeRemainingColumn(), + transient=True + ) as progress: + + init_model_packages = ( + "apogee", + "boss", + "bossnet", + "apogeenet", + "astronn_dist", + "astronn", + "source", + "spectrum", ) - print(f"kwargs are {kwargs}") - if kwargs: - for_deletion = [] - for k, v in kwargs.items(): - if hasattr(spectrum_model, k): - print(f"requiring {k} = {v}") - iterable = iterable.where(getattr(spectrum_model, k) == v) - 
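+ # The astra.models.* packages imported below are loaded for their side effect: each model
+ # registers itself as a BaseModel subclass, so the create_tables()/drop_tables() calls can
+ # discover every table (PipelineOutputModel itself is excluded from that set).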
for_deletion.append(k) - if for_deletion: - for k in for_deletion: - kwargs.pop(k) - - if page: - iterable = ( - iterable - .paginate(page, limit) - ) - else: - iterable = iterable.limit(limit) - total = limit or iterable.count() - - else: - if not spectrum_pks: - spectrum_pks.extend(kwargs.pop("spectra", [])) - else: - if "spectra" in kwargs: - raise ValueError("`spectra` given in `kwargs_path` and in command line") - - argspec = getfullargspec(f) - - if len(argspec.defaults) == len(argspec.args): - # It has a default for everything, and no spectrum model given, so give nothing - iterable = None - elif spectrum_pks: - example = Spectrum.get(spectrum_pks[0]) - spectrum_model = None - for expr, field in example.dependencies(): - if SpectrumMixin not in field.model.__mro__: - continue - try: - q = list(field.model.select().where(expr)) - except: - continue - else: - if q: - spectrum_model = field.model - log.info(f"Identified input spectra as type `{spectrum_model}`") - break - - log.warning(f"All given spectrum identifiers should come from the same model type") - - # SQLite has a limit on how many SQL variables can be used in a transaction. - def yield_spectrum_chunks(): - for chunk in chunked(spectrum_pks, 10_000): - yield from ( - spectrum_model - .select( - spectrum_model, - Source - ) - .join(Source, attr="source") - .where(spectrum_model.spectrum_pk.in_(chunk)) - ) - - iterable = yield_spectrum_chunks() - total = len(spectrum_pks) - else: - raise click.UsageError("Could not resolve spectrum identifiers.") - - if page is not None: - kwargs["page"] = page - if limit is not None: - kwargs["limit"] = limit + for package in init_model_packages: + import_module(f"astra.models.{package}") - if iterable is None: - for result in tqdm(f(**kwargs), total=0, unit=" spectra"): - None - - else: - for result in tqdm(f(iterable, **kwargs), total=total, unit=" spectra"): - None - - return None - - - -@cli.command() -@click.argument("paths", nargs=-1) -def run(paths, **kwargs): - """Execute one or many tasks.""" - import os - import json - from importlib import import_module - from astra.utils import log, expand_path - from astra.database.astradb import DataProduct - from tqdm import tqdm - - for path in paths: - log.info(f"Running {path}") - with open(path, "r") as fp: - content = json.load(fp) - - instructions = [content] if isinstance(content, dict) else content - N = len(instructions) - for i, instruction in enumerate(instructions, start=1): - log.info(f"Starting on instruction {i}/{N} in {path}") - - task_kwargs = instruction.get("task_kwargs", {}) - - # A couple of unfortunate hacks to fix instructions that were incomplete. - if (instruction["task_callable"] == "astra.contrib.aspcap.abundances.aspcap_abundances"): - if "pwd" not in task_kwargs: - # This one will fail. - log.warning(f"Skipping {i}-th (1-indexed) instruction because it's ASPCAP abundances without a pwd") - continue - - # Check if outputs already exist. - pwd = task_kwargs["pwd"] - if os.path.exists(os.path.join(expand_path(pwd), "stdout")): - log.warning(F"Skipping {i}-th (1-indexed) instruction because it's ASPCAP abundances and the outputs already exist") - continue - - # Get the task executable. 
- module_name, function_name = instruction["task_callable"].rsplit(".", 1) - module = import_module(module_name) - task_callable = getattr(module, function_name) - - has_data_products = "data_product" in task_kwargs # TODO: this special key should be defined elsewhere - if has_data_products: - # Resolve the data products - input_data_products = task_kwargs.pop("data_product", []) - if isinstance(input_data_products, str): - input_data_products = json.loads(input_data_products) - # The same data product can appear in this list multiple times, and we want to preserve order. - q = DataProduct.select().where(DataProduct.id << input_data_products) - unique_data_products = { dp.id: dp for dp in q } - task_kwargs["data_product"] = [unique_data_products[dp_pk] for dp_pk in input_data_products] - - log.info(f"Executing..") - try: - results = task_callable(**task_kwargs) - for result in results: - None - except: - log.exception(f"Exception in {task_callable} with {task_kwargs}") - raise - continue - - - log.info(f"Done") - - # Remove the path now that we're done. - #try: - # os.unlink(path) - #except: - # None - + models = set(BaseModel.__subclasses__()) - {PipelineOutputModel} + + if drop_tables: + tables_to_drop = [m for m in models if m.table_exists()] + if delay > 0: + t = progress.add_task(description=f"About to drop {len(tables_to_drop)} tables..", total=delay) + for i in range(delay): + progress.advance(t) + sleep(1) + + with database.atomic(): + database.drop_tables(tables_to_drop, cascade=True) + progress.remove_task(t) + t = progress.add_task(description="Creating tables", total=len(models)) + with database.atomic(): + database.create_tables(models) + if __name__ == "__main__": - cli(obj=dict()) + app() \ No newline at end of file diff --git a/bin/new_astra b/bin/new_astra deleted file mode 100755 index 55656fe..0000000 --- a/bin/new_astra +++ /dev/null @@ -1,416 +0,0 @@ -#!/usr/bin/env python3 -import typer -from typing import Optional -from typing_extensions import Annotated - -app = typer.Typer() - -@app.command() -def srun( - task: Annotated[str, typer.Argument(help="The task name to run (e.g., `aspcap`, or `astra.pipelines.aspcap.aspcap`).")], - model: Annotated[str, typer.Argument( - help=( - "The input model to use (e.g., `ApogeeCombinedSpectrum`, `BossCombinedSpectrum`). 
" - ) - )] = None, - nodes: Annotated[int, typer.Option(help="The number of nodes to use.", min=1)] = 1, - procs: Annotated[int, typer.Option(help="The number of processes to use per node.", min=1)] = 1, - limit: Annotated[int, typer.Option(help="Limit the number of inputs.", min=1)] = None, - account: Annotated[str, typer.Option(help="Slurm account")] = "sdss-np", - partition: Annotated[str, typer.Option(help="Slurm partition")] = None, - gres: Annotated[str, typer.Option(help="Slurm generic resources")] = None, - mem: Annotated[str, typer.Option(help="Memory per node")] = None, - time: Annotated[str, typer.Option(help="Wall-time")] = "24:00:00", -): - """Distribute an Astra task over many nodes using Slurm.""" - - partition = partition or account - - import os - import sys - import numpy as np - import concurrent.futures - import subprocess - from datetime import datetime - from tempfile import TemporaryDirectory - from peewee import JOIN - from importlib import import_module - from astra import models, __version__, generate_queries_for_task - from astra.utils import silenced, expand_path - from rich.progress import Progress, SpinnerColumn, TextColumn, TaskProgressColumn, TimeRemainingColumn, BarColumn, MofNCompleteColumn - - # TODO: no hard coding of paths - ASTRA = "/uufs/chpc.utah.edu/common/home/sdss50/sdsswork/mwm/spectro/astra/astra/astra_dev/bin/new_astra" - - _, q = next(generate_queries_for_task(task, model, limit)) - - total = q.count() - workers = nodes * procs - limit = int(np.ceil(total / workers)) - today = datetime.now().strftime("%Y-%m-%d") - - with Progress( - SpinnerColumn(), - TextColumn("[progress.description]{task.description}"), - BarColumn(), - transient=True - ) as p: - - executor = concurrent.futures.ProcessPoolExecutor(nodes) - - # Load a whole bunch of sruns in processes - futures = {} - with TemporaryDirectory(dir=expand_path("$PBS"), prefix=f"{task}-{today}-", delete=False) as td: - p.print(f"Working directory: {td}") - for n in range(nodes): - job_name = f"{task}" + (f"-{n}" if nodes > 1 else "") - # TODO: Let's not hard code this here. 
- commands = ["export CLUSTER=1", "echo hello"] - for page in range(n * procs, (n + 1) * procs): - commands.append(f"{ASTRA} run {task} {model} --limit {limit} --page {page + 1} &") - commands.append("wait") - - script_path = f"{td}/node_{n}.sh" - with open(script_path, "w") as fp: - fp.write("\n".join(commands)) - - os.system(f"chmod +x {script_path}") - executable = [ - "srun", - "--nodes=1", - f"--partition={partition}", - f"--account={account}", - f"--job-name={job_name}", - f"--time={time}", - f"--output={td}/{n}.out", - f"--error={td}/{n}.err", - ] - if mem is not None: - executable.append(f"--mem={mem}") - if gres is not None: - executable.append(f"--gres={gres}") - - executable.extend(["bash", "-c", f"{script_path}"]) - - t = p.add_task(description=f"Running {job_name}", total=None) - job = executor.submit( - subprocess.run, - executable, - capture_output=True - ) - futures[job] = (n, t) - - max_returncode = 0 - for future in concurrent.futures.as_completed(futures.keys()): - n, t = futures[future] - result = future.result() - if result.returncode == 0: - p.update(t, description=f"Completed") - p.remove_task(t) - else: - p.update(t, description=f"Error code {result.returncode} returned from {task}-{n}") - p.print(result.stderr.decode("utf-8")) - - max_returncode = max(max_returncode, result.returncode) - - sys.exit(max_returncode) - - -@app.command() -def run( - task: Annotated[str, typer.Argument(help="The task name to run (e.g., `aspcap`, or `astra.pipelines.aspcap.aspcap`).")], - spectrum_model: Annotated[str, typer.Argument( - help=( - "The spectrum model to use (e.g., `ApogeeCombinedSpectrum`, `BossCombinedSpectrum`). " - "If `None` is given then all spectrum models accepted by the task will be analyzed." - ) - )] = None, - limit: Annotated[int, typer.Option(help="Limit the number of spectra.", min=1)] = None, - page: Annotated[int, typer.Option(help="Page to start results from (`limit` spectra per `page`).", min=1)] = None, -): - """Run an Astra task on spectra.""" - from rich.progress import Progress, SpinnerColumn, TextColumn, TaskProgressColumn, TimeRemainingColumn, BarColumn, MofNCompleteColumn - - from astra import models, __version__, generate_queries_for_task - from astra.utils import resolve_task - - fun = resolve_task(task) - - with Progress(SpinnerColumn(), TextColumn("[progress.description]{task.description}")) as p: - t = p.add_task(description="Resolving task", total=None) - fun = resolve_task(task) - p.remove_task(t) - - messages = [] - with Progress( - SpinnerColumn(), - TextColumn("[progress.description]{task.description}"), - BarColumn(), - MofNCompleteColumn(), - TimeRemainingColumn(), - transient=False - ) as p: - - for model, q in generate_queries_for_task(fun, spectrum_model, limit, page=page): - t = p.add_task(description=f"Running {fun.__name__} on {model.__name__}", total=limit) - total = q.count() - p.update(t, total=total) - if total > 0: - for n, r in enumerate(fun(q), start=1): - p.update(t, advance=1, refresh=True) - messages.append(f"Processed {n} {model.__name__} spectra with {fun.__name__}") - p.update(t, completed=True) - - list(map(typer.echo, messages)) - - -@app.command() -def migrate( - apred: Optional[str] = typer.Option(None, help="APOGEE data reduction pipeline version."), - run2d: Optional[str] = typer.Option(None, help="BOSS data reduction pipeline version."), -): - """Migrate spectra and auxillary information to the Astra database.""" - - import os - import multiprocessing as mp - from signal import SIGKILL - from rich.progress 
import Text, Progress, SpinnerColumn, Text, TextColumn, TaskProgressColumn, TimeRemainingColumn, TimeElapsedColumn, BarColumn, MofNCompleteColumn as _MofNCompleteColumn - - class MofNCompleteColumn(_MofNCompleteColumn): - def render(self, task): - completed = int(task.completed) - total = f"{int(task.total):,}" if task.total is not None else "?" - total_width = len(str(total)) - return Text( - f"{completed:{total_width},d}{self.separator}{total}", - style="progress.download", - ) - - from astra.migrations.boss import ( - migrate_from_spall_file, - migrate_specfull_metadata_from_image_headers - ) - #from astra.migrations.apogee import ( - # migrate_apvisit_metadata_from_image_headers, - #) - from astra.migrations.new_apogee import ( - migrate_apogee_spectra_from_sdss5_apogee_drpdb, - migrate_dithered_metadata - ) - from astra.migrations.catalog import ( - migrate_healpix, - migrate_twomass_photometry, - migrate_unwise_photometry, - migrate_glimpse_photometry, - migrate_tic_v8_identifier, - migrate_gaia_source_ids, - migrate_gaia_dr3_astrometry_and_photometry, - migrate_zhang_stellar_parameters, - migrate_bailer_jones_distances, - migrate_gaia_synthetic_photometry - ) - from astra.migrations.misc import ( - compute_f_night_time_for_boss_visits, - compute_f_night_time_for_apogee_visits, - update_visit_spectra_counts, - compute_n_neighborhood, - update_galactic_coordinates, - compute_w1mag_and_w2mag, - fix_unsigned_apogee_flags - ) - from astra.migrations.reddening import update_reddening - from astra.migrations.targeting import ( - migrate_carton_assignments_to_bigbitfield, - migrate_targeting_cartons - ) - from astra.utils import silenced - - ptq = [] - try: - with Progress( - SpinnerColumn(), - TextColumn("[progress.description]{task.description}"), - BarColumn(), - MofNCompleteColumn(), - TimeElapsedColumn(), - TimeRemainingColumn(), - transient=True - ) as progress: - - def process_task(target, *args, description=None, **kwargs): - queue = mp.Queue() - - kwds = dict(queue=queue) - kwds.update(kwargs) - process = mp.Process(target=target, args=args, kwargs=kwds) - process.start() - - task = progress.add_task(description=(description or ""), total=None) - return (process, task, queue) - - - if apred is not None or run2d is not None: - if apred is not None: - if apred == "dr17": - from astra.migrations.apogee import migrate_sdss4_dr17_apogee_spectra_from_sdss5_catalogdb - ptq.append(process_task(migrate_sdss4_dr17_apogee_spectra_from_sdss5_catalogdb, description="Ingesting APOGEE dr17 spectra")) - else: - ptq.append(process_task(migrate_apogee_spectra_from_sdss5_apogee_drpdb, apred, description=f"Ingesting APOGEE {apred} spectra")) - if run2d is not None: - ptq.append(process_task(migrate_from_spall_file, run2d, description=f"Ingesting BOSS {run2d} spectra")) - - awaiting = set(t for p, t, q in ptq) - while awaiting: - for p, t, q in ptq: - try: - r = q.get(False) - if r is Ellipsis: - progress.update(t, completed=True) - awaiting.remove(t) - p.join() - progress.update(t, visible=False) - else: - progress.update(t, **r) - if "completed" in r and r.get("completed", None) == 0: - # reset the task - progress.reset(t) - except mp.queues.Empty: - pass - - # Now that we have sources and spectra, we can do other things. 
- ptq = [ - process_task(migrate_gaia_source_ids, description="Ingesting Gaia DR3 source IDs"), - process_task(migrate_twomass_photometry, description="Ingesting 2MASS photometry"), - process_task(migrate_unwise_photometry, description="Ingesting unWISE photometry"), - process_task(migrate_glimpse_photometry, description="Ingesting GLIMPSE photometry"), - process_task(migrate_specfull_metadata_from_image_headers, description="Ingesting specFull metadata"), - - - process_task(migrate_dithered_metadata, description="Ingesting APOGEE dithered metadata"), - #process_task(migrate_apvisit_metadata_from_image_headers, description="Ingesting apVisit metadata"), - process_task(migrate_healpix, description="Ingesting HEALPix values"), - process_task(migrate_tic_v8_identifier, description="Ingesting TIC v8 identifiers"), - process_task(update_galactic_coordinates, description="Computing Galactic coordinates"), - process_task(fix_unsigned_apogee_flags, description="Fix unsigned APOGEE flags"), - process_task(migrate_targeting_cartons, description="Ingesting targeting cartons"), - process_task(compute_f_night_time_for_apogee_visits, description="Computing f_night for APOGEE visits"), - process_task(update_visit_spectra_counts, description="Updating visit spectra counts"), - ] - # reddening needs unwise, 2mass, glimpse, - task_gaia, task_twomass, task_unwise, task_glimpse, task_specfull, *_ = [t for p, t, q in ptq] - reddening_requires = {task_twomass, task_unwise, task_glimpse, task_gaia} - started_reddening = False - awaiting = set(t for p, t, q in ptq) - while awaiting: - additional_tasks = [] - for p, t, q in ptq: - try: - r = q.get(False) - if r is Ellipsis: - progress.update(t, completed=True) - awaiting.remove(t) - p.join() - progress.update(t, visible=False) - if t == task_gaia: - # Add a bunch more tasks! - new_tasks = [ - process_task(migrate_gaia_dr3_astrometry_and_photometry, description="Ingesting Gaia DR3 astrometry and photometry"), - process_task(migrate_zhang_stellar_parameters, description="Ingesting Zhang stellar parameters"), - process_task(migrate_bailer_jones_distances, description="Ingesting Bailer-Jones distances"), - # commented out only because we are getting deadlock errors caused by some other process. 
- #process_task(migrate_gaia_synthetic_photometry, description="Ingesting Gaia synthetic photometry"), - process_task(compute_n_neighborhood, description="Computing n_neighborhood"), - ] - reddening_requires.update({t for p, t, q in new_tasks[:3]}) # reddening needs Gaia astrometry, Zhang parameters, and Bailer-Jones distances - additional_tasks.extend(new_tasks) - if t == task_specfull: - additional_tasks.append( - process_task(compute_f_night_time_for_boss_visits, description="Computing f_night for BOSS visits") - ) - if t == task_unwise: - additional_tasks.append( - process_task(compute_w1mag_and_w2mag, description="Computing W1, W2 mags") - ) - if not started_reddening and not (awaiting & reddening_requires): - started_reddening = True - #additional_tasks.append( - # process_task(update_reddening, description="Computing extinction") - #) - else: - progress.update(t, **r) - if "completed" in r and r.get("completed", None) == 0: - # reset the task - progress.reset(t) - - except mp.queues.Empty: - pass - - ptq.extend(additional_tasks) - awaiting |= set(t for p, t, q in additional_tasks) - - except KeyboardInterrupt: - """ - with silenced(): - import psutil - parent = psutil.Process(os.getpid()) - for child in parent.children(recursive=True): - child.kill() - """ - raise KeyboardInterrupt - - - -@app.command() -def init( - drop_tables: Optional[bool] = typer.Option(False, help="Drop tables if they exist."), - delay: Optional[int] = typer.Option(10, help="Delay in seconds to wait.") -): - """Initialize the Astra database.""" - - from time import sleep - from rich.progress import Progress, SpinnerColumn, TextColumn, TaskProgressColumn, TimeRemainingColumn - from importlib import import_module - from astra.models.base import (database, BaseModel) - from astra.models.pipeline import PipelineOutputModel - - with Progress( - SpinnerColumn(), - TextColumn("[progress.description]{task.description}"), - TimeRemainingColumn(), - transient=True - ) as progress: - - init_model_packages = ( - "apogee", - "boss", - "bossnet", - "apogeenet", - "astronn_dist", - "astronn", - "source", - "spectrum", - ) - for package in init_model_packages: - import_module(f"astra.models.{package}") - - models = set(BaseModel.__subclasses__()) - {PipelineOutputModel} - - if drop_tables: - tables_to_drop = [m for m in models if m.table_exists()] - if delay > 0: - t = progress.add_task(description=f"About to drop {len(tables_to_drop)} tables..", total=delay) - for i in range(delay): - progress.advance(t) - sleep(1) - - with database.atomic(): - database.drop_tables(tables_to_drop, cascade=True) - progress.remove_task(t) - - t = progress.add_task(description="Creating tables", total=len(models)) - with database.atomic(): - database.create_tables(models) - - -if __name__ == "__main__": - app() \ No newline at end of file diff --git a/src/astra/migrations/catalog.py b/src/astra/migrations/catalog.py index 5982bce..e35c33b 100644 --- a/src/astra/migrations/catalog.py +++ b/src/astra/migrations/catalog.py @@ -2,20 +2,13 @@ from typing import Optional from tqdm import tqdm from peewee import chunked, IntegerField, fn, JOIN, IntegrityError -from astra.models.base import database -from astra.models.source import Source -from astra.migrations.sdss5db.utils import get_approximate_rows +#from astra.migrations.sdss5db.utils import get_approximate_rows from astra.migrations.utils import NoQueue from astra.utils import log, flatten import numpy as np def migrate_healpix( - where=( - Source.healpix.is_null() - & Source.ra.is_null(False) - 
& Source.dec.is_null(False) - ), batch_size: Optional[int] = 500, limit: Optional[int] = None, nside: Optional[int] = 128, @@ -37,6 +30,9 @@ def migrate_healpix( :param lonlat: [optional] The HEALPix map is oriented in longitude and latitude coordinates. """ + from astra.models.base import database + from astra.models.source import Source + from healpy import ang2pix if queue is None: queue = NoQueue() @@ -48,7 +44,11 @@ def migrate_healpix( Source.ra, Source.dec, ) - .where(where) + .where( + Source.healpix.is_null() + & Source.ra.is_null(False) + & Source.dec.is_null(False) + ) .limit(limit) ) @@ -71,11 +71,13 @@ def migrate_healpix( def migrate_bailer_jones_distances( - where=(Source.r_med_geo.is_null() & Source.gaia_dr3_source_id.is_null(False) & (Source.gaia_dr3_source_id > 0)), batch_size=500, limit=None, queue=None ): + from astra.models.base import database + from astra.models.source import Source + if queue is None: queue = NoQueue() @@ -84,9 +86,10 @@ def migrate_bailer_jones_distances( q = ( Source .select() + .where( + (Source.r_med_geo.is_null() & Source.gaia_dr3_source_id.is_null(False) & (Source.gaia_dr3_source_id > 0)) + ) ) - if where: - q = q.where(where) q = ( q @@ -149,11 +152,13 @@ def migrate_bailer_jones_distances( def migrate_gaia_synthetic_photometry( - where=(Source.gaia_dr3_source_id.is_null(False) & Source.g_sdss_mag.is_null()), batch_size=500, limit=None, queue=None ): + from astra.models.base import database + from astra.models.source import Source + from astra.migrations.sdss5db.catalogdb import Gaia_dr3_synthetic_photometry_gspc if queue is None: queue = NoQueue() @@ -161,9 +166,8 @@ def migrate_gaia_synthetic_photometry( q = ( Source .select() + .where((Source.gaia_dr3_source_id.is_null(False) & Source.g_sdss_mag.is_null())) ) - if where: - q = q.where(where) q = ( q @@ -269,6 +273,8 @@ def migrate_zhang_stellar_parameters(where=None, batch_size: Optional[int] = 500 """ Migrate stellar parameters derived using Gaia XP spectra from Zhang, Green & Rix (2023) using the cross-match with `catalogid31` (v1). """ + from astra.models.base import database + from astra.models.source import Source from astra.migrations.sdss5db.catalogdb import CatalogdbModel, Gaia_DR3, BigIntegerField, ForeignKeyField @@ -391,6 +397,8 @@ class Meta: def migrate_tic_v8_identifier(catalogid_field_name="catalogid21", batch_size: Optional[int] = 500, limit: Optional[int] = None, queue=None): if queue is None: queue = NoQueue() + from astra.models.base import database + from astra.models.source import Source from astra.migrations.sdss5db.catalogdb import CatalogToTIC_v8 @@ -453,14 +461,6 @@ def migrate_tic_v8_identifier(catalogid_field_name="catalogid21", batch_size: Op def migrate_twomass_photometry( - where=( - ( - Source.j_mag.is_null() - | Source.h_mag.is_null() - | Source.k_mag.is_null() - ) - & Source.catalogid31.is_null(False) - ), limit: Optional[int] = None, batch_size: Optional[int] = 500, queue = None @@ -469,10 +469,22 @@ def migrate_twomass_photometry( Migrate 2MASS photometry from the database, using the cross-match with `catalogid31` (v1). 
""" + from astra.models.base import database + from astra.models.source import Source + if queue is None: queue = NoQueue() from astra.migrations.sdss5db.catalogdb import TwoMassPSC, CatalogToTwoMassPSC + where = ( + ( + Source.j_mag.is_null() + | Source.h_mag.is_null() + | Source.k_mag.is_null() + ) + & Source.catalogid31.is_null(False) + ) + q = ( Source .select(Source.catalogid31) @@ -573,17 +585,10 @@ def migrate_twomass_photometry( def migrate_unwise_photometry( - where=( - ( - Source.w1_flux.is_null() - | Source.w2_flux.is_null() - ) - & Source.catalogid21.is_null(False) - ), catalogid_field_name="catalogid21", batch_size: Optional[int] = 500, limit: Optional[int] = None, - queue = None + queue = None, ): """ Migrate 2MASS photometry from the database, using the cross-match with `catalogid21` (v0). @@ -591,9 +596,14 @@ def migrate_unwise_photometry( As of 2023-09-14, the cross-match does not yield anything with `catalog31`. """ + from astra.models.base import database + from astra.models.source import Source + if queue is None: queue = NoQueue() + + from astra.migrations.sdss5db.catalogdb import unWISE, CatalogTounWISE catalogid_field = getattr(Source, catalogid_field_name) @@ -602,13 +612,20 @@ def migrate_unwise_photometry( Source .select( Source.pk, + Source.sdss_id, catalogid_field ) - .where(where) + .where( + ( + Source.w1_flux.is_null() + | Source.w2_flux.is_null() + ) + & Source.catalogid21.is_null(False) + ) .order_by(catalogid_field.asc()) .limit(limit) ) - + updated = 0 queue.put(dict(total=limit or q.count())) if q: @@ -677,6 +694,9 @@ def migrate_glimpse_photometry(catalogid_field_name="catalogid31", batch_size: O Migrate Glimpse photometry from the database, using the cross-match with `catalogid31` (v1). """ + from astra.models.base import database + from astra.models.source import Source + if queue is None: queue = NoQueue() @@ -753,20 +773,16 @@ def migrate_glimpse_photometry(catalogid_field_name="catalogid31", batch_size: O def migrate_gaia_source_ids( - where=( - (Source.gaia_dr3_source_id.is_null()) - | (Source.gaia_dr3_source_id == 0) - | (Source.gaia_dr2_source_id.is_null()) - | (Source.gaia_dr2_source_id == 0) - ), limit: Optional[int] = None, - batch_size: Optional[int] = 500, + batch_size: Optional[int] = 1000, queue=None ): """ Migrate Gaia source IDs for anything that we might have missed. 
""" - + from astra.models.base import database + from astra.models.source import Source + if queue is None: queue = NoQueue() @@ -775,11 +791,17 @@ def migrate_gaia_source_ids( q = ( Source .select() - .where(where) + .where( + (Source.gaia_dr3_source_id.is_null()) + | (Source.gaia_dr3_source_id == 0) + | (Source.gaia_dr2_source_id.is_null()) + | (Source.gaia_dr2_source_id == 0) + ) .limit(limit) ) updated = [] + fields = set() queue.put(dict(total=limit or q.count(), description="Querying Gaia source IDs")) for chunk in chunked(q, batch_size): @@ -802,6 +824,8 @@ def migrate_gaia_source_ids( for catalogid, gaia_dr3_source_id in q: source = source_by_catalogid[catalogid] source.gaia_dr3_source_id = gaia_dr3_source_id + if gaia_dr3_source_id is not None: + fields.add(Source.gaia_dr3_source_id) updated.append(source) q = ( @@ -816,10 +840,13 @@ def migrate_gaia_source_ids( for catalogid, gaia_dr2_source_id in q: source = source_by_catalogid[catalogid] source.gaia_dr2_source_id = gaia_dr2_source_id + if gaia_dr2_source_id is not None: + fields.add(Source.gaia_dr2_source_id) updated.append(source) queue.put(dict(advance=batch_size)) + fields = list(fields) n_updated, updated = (0, list(set(updated))) queue.put(dict(total=len(updated), completed=0, description="Ingesting Gaia DR3 source IDs")) integrity_errors = [] @@ -829,18 +856,15 @@ def migrate_gaia_source_ids( Source .bulk_update( chunk, - fields=[ - Source.gaia_dr2_source_id, - Source.gaia_dr3_source_id - ] + fields=fields ) ) except IntegrityError: integrity_errors.append(chunk) - + queue.put(dict(advance=batch_size)) - if integrity_errors: - log.warning(f"Integrity errors encountered for {len(integrity_errors)} chunks") + #if integrity_errors: + # log.warning(f"Integrity errors encountered for {len(integrity_errors)} chunks") queue.put(Ellipsis) return n_updated @@ -848,17 +872,6 @@ def migrate_gaia_source_ids( def migrate_gaia_dr3_astrometry_and_photometry( - where = ( - ( - Source.g_mag.is_null() - | Source.bp_mag.is_null() - | Source.rp_mag.is_null() - ) - & ( - Source.gaia_dr3_source_id.is_null(False) - & (Source.gaia_dr3_source_id > 0) - ) - ), limit: Optional[int] = None, batch_size: Optional[int] = 500, queue=None @@ -874,6 +887,9 @@ def migrate_gaia_dr3_astrometry_and_photometry( :param limit: [optional] Limit the update to `limit` records. Useful for testing. """ + from astra.models.base import database + from astra.models.source import Source + if queue is None: queue = NoQueue() @@ -897,7 +913,17 @@ class Meta: Source .select(Source.gaia_dr3_source_id) .distinct() - .where(where) + .where( + ( + Source.g_mag.is_null() + | Source.bp_mag.is_null() + | Source.rp_mag.is_null() + ) + & ( + Source.gaia_dr3_source_id.is_null(False) + & (Source.gaia_dr3_source_id > 0) + ) + ) .order_by( Source.gaia_dr3_source_id.asc() ) @@ -959,8 +985,6 @@ class Meta: ) ) ) - if where: - q = q.where(where) updated_sources = [] for source in q: @@ -1001,199 +1025,3 @@ class Meta: return updated -def migrate_sources_from_sdss5_catalogdb(batch_size: Optional[int] = 500, limit: Optional[int] = None): - """ - Migrate all catalog sources stored in the SDSS-V database. - - This creates a unique identifier per astronomical source (akin to a `sdss_id`) and links all possible - catalog identifiers (`catalogdb.catalog.catalogid`) to those unique sources. - - :param batch_size: [optional] - The batch size to use when upserting data. - - :param limit: [optional] - Limit the initial catalog queries for testing purposes. 
- - :returns: - A tuple of new `sdss_id` identifiers created. - """ - raise ProgrammingError - - from astra.migrations.sdss5db.catalogdb import CatalogdbModel - - class Catalog_ver25_to_ver31_full_unique(CatalogdbModel): - - id = IntegerField(primary_key=True) - - class Meta: - table_name = 'catalog_ver25_to_ver31_full_unique' - - log.info(f"Querying catalogdb.catalog_ver25_to_ver31_unique") - - q = ( - Catalog_ver25_to_ver31_full_unique - .select( - Catalog_ver25_to_ver31_full_unique.id, - Catalog_ver25_to_ver31_full_unique.lowest_catalogid, - Catalog_ver25_to_ver31_full_unique.highest_catalogid, - ) - .limit(limit) - .tuples() - .iterator() - ) - - # Sometimes the highest_catalogid appears twice. There's a good reason for this. - # I just don't know what it is. But we need unique-ness, and we need to link to - # the lower catalog identifiers. - next_sdss_id = 1 - source_data, lookup_sdss_id_from_catalog_id, lookup_catalog_id_from_sdss_id = ({}, {}, {}) - for sdss_id, lowest, highest in tqdm(q, total=limit or get_approximate_rows(Catalog_ver25_to_ver31_full_unique)): - - # Do we already have an sdss_id assigned to this highest catalog identifier? - sdss_id_1 = lookup_sdss_id_from_catalog_id.get(highest, None) - sdss_id_2 = lookup_sdss_id_from_catalog_id.get(lowest, None) - - if sdss_id_1 is not None and sdss_id_2 is not None and sdss_id_1 != sdss_id_2: - # We need to amalgamate these two. - affected = [] - affected.extend(lookup_catalog_id_from_sdss_id[sdss_id_1]) - affected.extend(lookup_catalog_id_from_sdss_id[sdss_id_2]) - - # merge both into sdss_id_1 - source_data[sdss_id_1] = dict( - sdss_id=sdss_id_1, - sdss5_catalogid_v1=max(affected) - ) - for catalogid in affected: - lookup_sdss_id_from_catalog_id[catalogid] = sdss_id_1 - - lookup_catalog_id_from_sdss_id[sdss_id_1] = affected - - del source_data[sdss_id_2] - del lookup_catalog_id_from_sdss_id[sdss_id_2] - - else: - sdss_id = sdss_id_1 or sdss_id_2 - if sdss_id is None: - sdss_id = 0 + next_sdss_id - next_sdss_id += 1 - - lookup_catalog_id_from_sdss_id.setdefault(sdss_id, []) - lookup_catalog_id_from_sdss_id[sdss_id].extend((lowest, highest)) - - lookup_sdss_id_from_catalog_id[lowest] = sdss_id - lookup_sdss_id_from_catalog_id[highest] = sdss_id - source_data[sdss_id] = dict( - sdss_id=sdss_id, - sdss5_catalogid_v1=highest - ) - - log.info(f"There are {len(source_data)} unique `sdss_id` entries so far") - - class Catalog_ver25_to_ver31_full_all(CatalogdbModel): - - id = IntegerField(primary_key=True) - - class Meta: - table_name = 'catalog_ver25_to_ver31_full_all' - - log.info(f"Querying catalogdb.catalog_ver25_to_ver31_full_all") - - q = ( - Catalog_ver25_to_ver31_full_all - .select( - Catalog_ver25_to_ver31_full_all.lowest_catalogid, - Catalog_ver25_to_ver31_full_all.highest_catalogid - ) - .limit(limit) - .tuples() - .iterator() - ) - - for lowest, highest in tqdm(q, total=limit or get_approximate_rows(Catalog_ver25_to_ver31_full_all)): - - sdss_id_1 = lookup_sdss_id_from_catalog_id.get(highest, None) - sdss_id_2 = lookup_sdss_id_from_catalog_id.get(lowest, None) - - if sdss_id_1 is not None and sdss_id_2 is not None and sdss_id_1 != sdss_id_2: - # We need to amalgamate these two. 
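The amalgamation referred to by the comment above reduces to a small group-merge over (lowest, highest) catalogid pairs: whenever a pair bridges two existing `sdss_id` groups, both groups are absorbed into the first. A stripped-down sketch with illustrative pairs; the merge rule is taken from this (now deleted) function, everything else is simplified.

def group_catalogids(pairs):
    # Assign an incrementing sdss_id to each connected group of catalog identifiers.
    next_sdss_id = 1
    sdss_id_by_catalogid = {}      # catalogid -> sdss_id
    catalogids_by_sdss_id = {}     # sdss_id -> [catalogid, ...]

    for lowest, highest in pairs:
        id_1 = sdss_id_by_catalogid.get(highest)
        id_2 = sdss_id_by_catalogid.get(lowest)

        if id_1 is not None and id_2 is not None and id_1 != id_2:
            # The pair bridges two existing groups: merge the second into the first.
            affected = catalogids_by_sdss_id.pop(id_2) + catalogids_by_sdss_id[id_1]
            catalogids_by_sdss_id[id_1] = affected
            for catalogid in affected:
                sdss_id_by_catalogid[catalogid] = id_1
        else:
            sdss_id = id_1 or id_2
            if sdss_id is None:
                sdss_id, next_sdss_id = next_sdss_id, next_sdss_id + 1
            catalogids_by_sdss_id.setdefault(sdss_id, []).extend((lowest, highest))
            sdss_id_by_catalogid[lowest] = sdss_id
            sdss_id_by_catalogid[highest] = sdss_id

    return sdss_id_by_catalogid

print(group_catalogids([(1, 10), (2, 20), (2, 10)]))  # the third pair merges the first two groups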
- affected = [] - affected.extend(lookup_catalog_id_from_sdss_id[sdss_id_1]) - affected.extend(lookup_catalog_id_from_sdss_id[sdss_id_2]) - - # merge both into sdss_id_1 - source_data[sdss_id_1] = dict( - sdss_id=sdss_id_1, - sdss5_catalogid_v1=max(affected) - ) - for catalogid in affected: - lookup_sdss_id_from_catalog_id[catalogid] = sdss_id_1 - - lookup_catalog_id_from_sdss_id[sdss_id_1] = affected - - del source_data[sdss_id_2] - del lookup_catalog_id_from_sdss_id[sdss_id_2] - - else: - sdss_id = sdss_id_1 or sdss_id_2 - if sdss_id is None: - sdss_id = 0 + next_sdss_id - next_sdss_id += 1 - - lookup_catalog_id_from_sdss_id.setdefault(sdss_id, []) - lookup_catalog_id_from_sdss_id[sdss_id].extend((lowest, highest)) - - lookup_sdss_id_from_catalog_id[lowest] = sdss_id - lookup_sdss_id_from_catalog_id[highest] = sdss_id - source_data[sdss_id] = dict( - sdss_id=sdss_id, - sdss5_catalogid_v1=highest - ) - - log.info(f"There are now {len(source_data)} unique `sdss_id` entries so far") - - # Create the Source - new_source_ids = [] - with database.atomic(): - # Need to chunk this to avoid SQLite limits. - with tqdm(desc="Upserting", unit="sources", total=len(source_data)) as pb: - for chunk in chunked(source_data.values(), batch_size): - new_source_ids.extend( - Source - .insert_many(chunk) - .on_conflict_ignore() - .returning(Source.sdss_id) - .tuples() - .execute() - ) - pb.update(min(batch_size, len(chunk))) - pb.refresh() - - log.info(f"Inserted {len(new_source_ids)} new sources") - - log.info("Linking catalog identifiers to SDSS identifiers") - - data_generator = ( - dict(catalogid=catalogid, sdss_id=sdss_id) - for catalogid, sdss_id in lookup_sdss_id_from_catalog_id.items() - ) - - with database.atomic(): - with tqdm(desc="Linking catalog identifiers to unique sources", total=len(lookup_sdss_id_from_catalog_id)) as pb: - for chunk in chunked(data_generator, batch_size): - ( - SDSSCatalog - .insert_many(chunk) - .on_conflict_ignore() - .returning(SDSSCatalog.catalogid) - .tuples() - .execute() - ) - pb.update(min(batch_size, len(chunk))) - pb.refresh() - - return tuple(new_source_ids) - - - - diff --git a/src/astra/migrations/misc.py b/src/astra/migrations/misc.py index c302e68..80bb5ab 100644 --- a/src/astra/migrations/misc.py +++ b/src/astra/migrations/misc.py @@ -3,15 +3,12 @@ import astropy.coordinates as coord import astropy.units as u from scipy.signal import argrelmin -from peewee import chunked, fn, JOIN +from peewee import chunked, fn, JOIN, EXCLUDED import concurrent.futures import pickle from astra.utils import flatten, expand_path -from astra.models.base import database -from astra.models.source import Source -from astra.models.apogee import ApogeeVisitSpectrum -from astra.models.boss import BossVisitSpectrum + from astra.migrations.utils import NoQueue from astropy.coordinates import SkyCoord from astropy import units as u @@ -23,14 +20,14 @@ von = lambda v: v or np.nan def compute_w1mag_and_w2mag( - where=( - (Source.w1_flux.is_null(False) & Source.w1_mag.is_null(True)) - | (Source.w2_flux.is_null(False) & Source.w2_mag.is_null(True)) - ), limit=None, batch_size=1000, queue=None ): + from astra.models.base import database + from astra.models.source import Source + from astra.models.apogee import ApogeeVisitSpectrum + from astra.models.boss import BossVisitSpectrum if queue is None: queue = NoQueue() @@ -43,7 +40,10 @@ def compute_w1mag_and_w2mag( Source.w2_flux, Source.w2_dflux, ) - .where(where) + .where( + (Source.w1_flux.is_null(False) & Source.w1_mag.is_null(True)) + | 
(Source.w2_flux.is_null(False) & Source.w2_mag.is_null(True)) + ) .limit(limit) ) n_updated = 0 @@ -76,12 +76,15 @@ def compute_w1mag_and_w2mag( def update_galactic_coordinates( - where=(Source.ra.is_null(False) & Source.l.is_null(True)), limit=None, frame="icrs", batch_size=1000, queue=None ): + from astra.models.base import database + from astra.models.source import Source + from astra.models.apogee import ApogeeVisitSpectrum + from astra.models.boss import BossVisitSpectrum if queue is None: queue = NoQueue() @@ -92,7 +95,7 @@ def update_galactic_coordinates( Source.ra, Source.dec ) - .where(where) + .where((Source.ra.is_null(False) & Source.l.is_null(True))) .limit(limit) ) @@ -124,6 +127,12 @@ def update_galactic_coordinates( def fix_unsigned_apogee_flags(queue): + + from astra.models.base import database + from astra.models.source import Source + from astra.models.apogee import ApogeeVisitSpectrum + from astra.models.boss import BossVisitSpectrum + if queue is None: queue = NoQueue() @@ -158,10 +167,6 @@ def compute_gonzalez_hernandez_irfm_effective_temperatures_from_vmk( model, logg_field, fe_h_field, - where=( - Source.v_jkc_mag.is_null(False) - & Source.k_mag.is_null(False) - ), dwarf_giant_logg_split=3.8, batch_size=10_000 ): @@ -176,7 +181,11 @@ def compute_gonzalez_hernandez_irfm_effective_temperatures_from_vmk( giant_colour_range = [0.7, 3.8] giant_fe_h_range = [-4.0, 0.1] ''' - + from astra.models.base import database + from astra.models.source import Source + from astra.models.apogee import ApogeeVisitSpectrum + from astra.models.boss import BossVisitSpectrum + B_dwarf = np.array([0.5201, 0.2511, -0.0118, -0.0186, 0.0408, 0.0033]) B_giant = np.array([0.5293, 0.2489, -0.0119, -0.0042, 0.0135, 0.0010]) @@ -198,12 +207,11 @@ def compute_gonzalez_hernandez_irfm_effective_temperatures_from_vmk( (model.v_astra == __version__) & logg_field.is_null(False) & fe_h_field.is_null(False) + & Source.v_jkc_mag.is_null(False) + & Source.k_mag.is_null(False) ) ) - if where: - q = q.where(where) - n_updated, batch = (0, []) for row in tqdm(q.iterator()): X = (row._source.v_jkc_mag or np.nan) - (row._source.k_mag or np.nan) @@ -265,15 +273,15 @@ def compute_gonzalez_hernandez_irfm_effective_temperatures_from_vmk( def compute_casagrande_irfm_effective_temperatures( model, fe_h_field, - where=( - Source.v_jkc_mag.is_null(False) - & Source.k_mag.is_null(False) - ), batch_size=10_000 ): """ Compute IRFM effective temperatures using the V-Ks colour and the Casagrande et al. (2010) scale. 
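Both IRFM routines evaluate a colour-metallicity polynomial in X = V - Ks. The diff does not show the evaluation itself, so the functional form below — theta_eff = b0 + b1*X + b2*X^2 + b3*X*[Fe/H] + b4*[Fe/H] + b5*[Fe/H]^2 with Teff = 5040 K / theta_eff, as in González Hernández & Bonifacio (2009) and Casagrande et al. (2010) — is an assumption; only the dwarf coefficients are copied from the code above.

import numpy as np

# Dwarf coefficients from compute_gonzalez_hernandez_irfm_effective_temperatures_from_vmk above.
B_DWARF = np.array([0.5201, 0.2511, -0.0118, -0.0186, 0.0408, 0.0033])

def irfm_teff(v_mag, k_mag, fe_h, b=B_DWARF):
    # Assumed form: theta_eff = b0 + b1*X + b2*X^2 + b3*X*[Fe/H] + b4*[Fe/H] + b5*[Fe/H]^2
    X = v_mag - k_mag
    theta = b[0] + b[1] * X + b[2] * X**2 + b[3] * X * fe_h + b[4] * fe_h + b[5] * fe_h**2
    return 5040.0 / theta

print(irfm_teff(v_mag=5.2, k_mag=3.6, fe_h=-0.2))  # ~5700 K, within the dwarf colour validity range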
""" + from astra.models.base import database + from astra.models.source import Source + from astra.models.apogee import ApogeeVisitSpectrum + from astra.models.boss import BossVisitSpectrum valid_v_k = [0.78, 3.15] @@ -294,11 +302,12 @@ def compute_casagrande_irfm_effective_temperatures( Source, ) .join(Source, on=(model.source_pk == Source.pk), attr="_source") + .where( + Source.v_jkc_mag.is_null(False) + & Source.k_mag.is_null(False) + ) ) - if where: - q = q.where(where) - n_updated, batch = (0, []) for row in tqdm(q.iterator()): @@ -353,8 +362,14 @@ def update_visit_spectra_counts( apogee_where=None, boss_where=None, batch_size=10_000, - queue=None + queue=None, + k=1000 ): + from astra.models.base import database + from astra.models.source import Source + from astra.models.apogee import ApogeeVisitSpectrum + from astra.models.boss import BossVisitSpectrum + if queue is None: queue = NoQueue() @@ -384,7 +399,7 @@ def update_visit_spectra_counts( q_apogee_counts = ( ApogeeVisitSpectrum .select( - ApogeeVisitSpectrum.source_pk, + ApogeeVisitSpectrum.source.alias("pk"), fn.count(ApogeeVisitSpectrum.pk).alias("n_apogee_visits"), fn.min(ApogeeVisitSpectrum.mjd).alias("apogee_min_mjd"), fn.max(ApogeeVisitSpectrum.mjd).alias("apogee_max_mjd"), @@ -418,7 +433,7 @@ def update_visit_spectra_counts( q_boss_counts = ( BossVisitSpectrum .select( - BossVisitSpectrum.source_pk, + BossVisitSpectrum.source.alias("pk"), fn.count(BossVisitSpectrum.pk).alias("n_boss_visits"), fn.min(BossVisitSpectrum.mjd).alias("boss_min_mjd"), fn.max(BossVisitSpectrum.mjd).alias("boss_max_mjd"), @@ -429,75 +444,67 @@ def update_visit_spectra_counts( ) # merge counts - defaults = dict( - n_boss_visits=0, - n_apogee_visits=0, - apogee_min_mjd=None, - apogee_max_mjd=None, - boss_min_mjd=None, - boss_max_mjd=None - ) - all_counts = {} - queue.put(dict(total=q_apogee_counts.count(), description="Querying APOGEE visit counts")) - for each in q_apogee_counts.iterator(): - source_pk = each.pop("source") - all_counts[source_pk] = defaults - all_counts[source_pk].update(each) - queue.put(dict(advance=1)) - + update = {} + queue.put(dict(total=q_apogee_counts.count(), completed=0, description="Querying APOGEE visit counts")) + for i, each in enumerate(q_apogee_counts.iterator()): + pk = each["pk"] + if pk is not None and each["apogee_min_mjd"] is not None: + update.setdefault(pk, {}) + update[pk].update(each) + if i > 0 and i % k == 0: + queue.put(dict(advance=k)) + queue.put(dict(total=q_boss_counts.count(), description="Querying BOSS visit counts", completed=0)) - for each in q_boss_counts.iterator(): - source_pk = each.pop("source") - all_counts.setdefault(source_pk, defaults) - all_counts[source_pk].update(each) - queue.put(dict(advance=1)) + for i, each in enumerate(q_boss_counts.iterator()): + pk = each["pk"] + if pk is not None and each["boss_min_mjd"] is not None: + update.setdefault(pk, {}) + update[pk].update(each) + if i > 0 and i % k == 0: + queue.put(dict(advance=k)) - update = [] - queue.put(dict(total=Source.select().count(), description="Collecting source visit counts", completed=0)) - for s in Source.select().iterator(): - for k, v in all_counts.get(s.pk, {}).items(): - setattr(s, k, v) - update.append(s) - queue.put(dict(total=len(update), description="Updating source visit counts", completed=0)) - for batch in chunked(update, batch_size): - # Ugh some issue where if we are only setting Nones for all then if we supply the field it dies - fields = {"n_apogee_visits", "n_boss_visits"} - for b in batch: - if 
b.apogee_min_mjd is not None: - fields.add("apogee_min_mjd") - if b.apogee_max_mjd is not None: - fields.add("apogee_max_mjd") - if b.boss_min_mjd is not None: - fields.add("boss_min_mjd") - if b.boss_max_mjd is not None: - fields.add("boss_max_mjd") - fields = [getattr(Source, f) for f in fields] + for chunk in chunked(update.values(), batch_size): with database.atomic(): - Source.bulk_update( - batch, - fields=fields + ( + Source + .insert_many(chunk) + .on_conflict( + conflict_target=[Source.pk], + preserve=[ + Source.n_apogee_visits, + Source.n_boss_visits, + Source.apogee_min_mjd, + Source.apogee_max_mjd, + Source.boss_min_mjd, + Source.boss_max_mjd + ], + where=( + (EXCLUDED.n_apogee_visits > Source.n_apogee_visits) + | (EXCLUDED.n_boss_visits > Source.n_boss_visits) + | ((EXCLUDED.n_apogee_visits > 0) & (Source.n_apogee_visits.is_null())) + | ((EXCLUDED.n_boss_visits > 0) & (Source.n_boss_visits.is_null())) + ) + ) + .execute() ) - queue.put(dict(advance=batch_size)) + queue.put(dict(advance=batch_size)) queue.put(Ellipsis) - return len(update) + return None def compute_n_neighborhood( - where=( - ( - Source.n_neighborhood.is_null() - | (Source.n_neighborhood < 0) - ) - & Source.gaia_dr3_source_id.is_null(False) - ), radius=3, # arcseconds brightness=5, # magnitudes batch_size=1000, limit=None, queue=None ): + from astra.models.base import database + from astra.models.source import Source + from astra.models.apogee import ApogeeVisitSpectrum + from astra.models.boss import BossVisitSpectrum #"Sources within 3\" and G_MAG < G_MAG_source + 5" if queue is None: queue = NoQueue() @@ -514,7 +521,13 @@ class Meta: q = ( Source .select() - .where(where) + .where( + ( + Source.n_neighborhood.is_null() + | (Source.n_neighborhood < 0) + ) + & Source.gaia_dr3_source_id.is_null(False) + ) .limit(limit) ) @@ -554,7 +567,8 @@ class Meta: n_updated += len(batch_update) if len(batch_update) > 0: - Source.bulk_update(batch_update, fields=[Source.n_neighborhood]) + with database.atomic(): + Source.bulk_update(batch_update, fields=[Source.n_neighborhood]) queue.put(dict(advance=batch_size)) queue.put(Ellipsis) @@ -562,6 +576,7 @@ class Meta: def set_missing_gaia_source_ids_to_null(): + from astra.models.source import Source ( Source .update(gaia_dr3_source_id=None) @@ -575,13 +590,7 @@ def set_missing_gaia_source_ids_to_null(): .execute() ) -def compute_f_night_time_for_boss_visits( - where=( - BossVisitSpectrum.f_night_time.is_null() - & BossVisitSpectrum.tai_end.is_null(False) - & BossVisitSpectrum.tai_beg.is_null(False) # sometimes we don't have tai_beg or tai_end - ), - limit=None, batch_size=1000, n_time=256, max_workers=64, queue=None): +def compute_f_night_time_for_boss_visits(limit=None, batch_size=1000, n_time=256, max_workers=64, queue=None): """ Compute `f_night_time`, which is the observation mid-point expressed as a fraction of time between local sunset and sunrise. @@ -600,11 +609,18 @@ def compute_f_night_time_for_boss_visits( :param max_workers: The maximum number of workers to use when computing `f_night_time`. 
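The rewritten `update_visit_spectra_counts` above replaces its read-modify-write loop with a single conflict-aware insert: rows keyed on `Source.pk` are inserted, and on conflict the listed columns are taken from the incoming row only when the `where` guard (built on `EXCLUDED`, peewee's handle on the proposed values) says the new counts supersede what is stored. A reduced sketch with two illustrative rows, assuming the primary keys already exist so the statement always takes the update path.

from peewee import EXCLUDED
from astra.models.base import database
from astra.models.source import Source

rows = [  # illustrative counts, keyed on existing Source.pk values
    dict(pk=1, n_apogee_visits=12, n_boss_visits=0,
         apogee_min_mjd=59000, apogee_max_mjd=60100, boss_min_mjd=None, boss_max_mjd=None),
    dict(pk=2, n_apogee_visits=0, n_boss_visits=3,
         apogee_min_mjd=None, apogee_max_mjd=None, boss_min_mjd=59500, boss_max_mjd=59900),
]

with database.atomic():
    (
        Source
        .insert_many(rows)
        .on_conflict(
            conflict_target=[Source.pk],
            # On conflict, take these columns from the incoming (EXCLUDED) row...
            preserve=[
                Source.n_apogee_visits, Source.n_boss_visits,
                Source.apogee_min_mjd, Source.apogee_max_mjd,
                Source.boss_min_mjd, Source.boss_max_mjd,
            ],
            # ...but only when the incoming counts supersede the stored ones.
            where=(
                (EXCLUDED.n_apogee_visits > Source.n_apogee_visits)
                | (EXCLUDED.n_boss_visits > Source.n_boss_visits)
                | ((EXCLUDED.n_apogee_visits > 0) & (Source.n_apogee_visits.is_null()))
                | ((EXCLUDED.n_boss_visits > 0) & (Source.n_boss_visits.is_null()))
            ),
        )
        .execute()
    )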
""" - + from astra.models.base import database + from astra.models.source import Source + from astra.models.apogee import ApogeeVisitSpectrum + from astra.models.boss import BossVisitSpectrum q = ( BossVisitSpectrum .select() - .where(where) + .where( + BossVisitSpectrum.f_night_time.is_null() + & BossVisitSpectrum.tai_end.is_null(False) + & BossVisitSpectrum.tai_beg.is_null(False) # sometimes we don't have tai_beg or tai_end + ) .limit(limit) ) @@ -613,7 +629,7 @@ def compute_f_night_time_for_boss_visits( return _compute_f_night_time_for_visits(q, BossVisitSpectrum, get_obs_time, batch_size, n_time, max_workers, queue) -def compute_f_night_time_for_apogee_visits(where=ApogeeVisitSpectrum.f_night_time.is_null(), limit=None, batch_size=1000, n_time=256, max_workers=64, queue=None): +def compute_f_night_time_for_apogee_visits(limit=None, batch_size=1000, n_time=256, max_workers=64, queue=None): """ Compute `f_night_time`, which is the observation mid-point expressed as a fraction of time between local sunset and sunrise. @@ -632,10 +648,14 @@ def compute_f_night_time_for_apogee_visits(where=ApogeeVisitSpectrum.f_night_tim :param max_workers: The maximum number of workers to use when computing `f_night_time`. """ + from astra.models.base import database + from astra.models.source import Source + from astra.models.apogee import ApogeeVisitSpectrum + from astra.models.boss import BossVisitSpectrum q = ( ApogeeVisitSpectrum .select() - .where(where) + .where(ApogeeVisitSpectrum.f_night_time.is_null()) .limit(limit) ) return _compute_f_night_time_for_visits(q, ApogeeVisitSpectrum, lambda v: v.date_obs, batch_size, n_time, max_workers, queue) @@ -663,6 +683,8 @@ def _compute_f_night_time(pk, observatory, time, n_time): def _compute_f_night_time_for_visits(q, model, get_obs_time, batch_size, n_time, max_workers, queue): + from astra.models.base import database + if queue is None: queue = NoQueue() executor = concurrent.futures.ProcessPoolExecutor(max_workers=max_workers) @@ -679,9 +701,11 @@ def _compute_f_night_time_for_visits(q, model, get_obs_time, batch_size, n_time, futures.append(executor.submit(_compute_f_night_time, visit.pk, observatory, time, n_time)) visit_by_pk[visit.pk] = visit + queue.put(dict(advance=1)) updated = [] n_updated = 0 + queue.put(dict(total=len(futures), description="Updating f_night", completed=0)) for future in concurrent.futures.as_completed(futures): pk, f_night_time = future.result() visit = visit_by_pk[pk] @@ -689,17 +713,6 @@ def _compute_f_night_time_for_visits(q, model, get_obs_time, batch_size, n_time, updated.append(visit) queue.put(dict(advance=1)) - if len(updated) >= batch_size: - with database.atomic(): - n_updated += ( - model - .bulk_update( - updated, - fields=[model.f_night_time], - ) - ) - updated = [] - if len(updated) > 0: with database.atomic(): n_updated += ( @@ -707,7 +720,9 @@ def _compute_f_night_time_for_visits(q, model, get_obs_time, batch_size, n_time, .bulk_update( updated, fields=[model.f_night_time], + batch_size=batch_size ) ) + queue.put(Ellipsis) return n_updated diff --git a/src/astra/migrations/new_apogee.py b/src/astra/migrations/new_apogee.py index 0c857d5..63e590e 100644 --- a/src/astra/migrations/new_apogee.py +++ b/src/astra/migrations/new_apogee.py @@ -61,11 +61,12 @@ def migrate_dithered_metadata( update = [] queue.put(dict(description="Scraping APOGEE visit spectra headers", total=q.count(), completed=0)) with concurrent.futures.ProcessPoolExecutor(max_workers) as executor: - futures, spectra = ({}, {}) - for total, s in 
enumerate(q.iterator(), start=1): + futures, spectra, total = ({}, {}, 0) + for s in q.iterator(): futures[s.pk] = executor.submit(_migrate_dithered_metadata, s.pk, s.absolute_path) spectra[s.pk] = s queue.put(dict(advance=1)) + total += 1 queue.put(dict(description="Parsing APOGEE visit spectra headers", total=total, completed=0)) for future in concurrent.futures.as_completed(futures.values()): diff --git a/src/astra/models/boss.py b/src/astra/models/boss.py index 52c1c23..269d871 100644 --- a/src/astra/models/boss.py +++ b/src/astra/models/boss.py @@ -165,6 +165,8 @@ def e_flux(self): class Meta: indexes = ( (("release", "run2d", "fieldid", "mjd", "catalogid"), True), + # The following index makes it easier to count the number of unique spectra per source (over different reduction versions). + (("source_pk", "telescope", "mjd", "fieldid", "plateid"), False), ) @property
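The new multi-column index on `BossVisitSpectrum` uses peewee's `Meta.indexes` convention, where each entry is a tuple of column names plus a boolean marking whether the index is unique. A standalone sketch on a throwaway SQLite model so the (columns, unique) shape is easy to see; the model, database, and column set here are illustrative, not Astra's.

from peewee import Model, SqliteDatabase, IntegerField, TextField

db = SqliteDatabase(":memory:")

class VisitSpectrum(Model):
    source_pk = IntegerField()
    telescope = TextField()
    mjd = IntegerField()
    fieldid = IntegerField()

    class Meta:
        database = db
        indexes = (
            # (column names, unique?)
            (("source_pk", "telescope", "mjd", "fieldid"), False),  # speeds up per-source visit counts
            (("telescope", "mjd", "fieldid"), True),                # example of a unique composite constraint
        )

db.create_tables([VisitSpectrum])
print([index.name for index in db.get_indexes("visitspectrum")])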