Skip to content

Commit

Permalink
add update code to tile-based load
Browse files Browse the repository at this point in the history
  • Loading branch information
weaverba137 committed Oct 10, 2024
1 parent 8bd51ce commit b61e658
Show file tree
Hide file tree
Showing 3 changed files with 81 additions and 129 deletions.
108 changes: 8 additions & 100 deletions doc/nb/TestNewTile.ipynb
Original file line number Diff line number Diff line change
Expand Up @@ -654,7 +654,7 @@
],
"source": [
"load_tiles = list()\n",
"load_exposures = list() \n",
"load_exposures = list()\n",
"for new_tile in (load_new_tiles + load_updated_tiles):\n",
" row_index = np.where((update_exposures_table['TILEID'] == new_tile.tileid) & (update_exposures_table['EFFTIME_SPEC'] > 0))[0]\n",
" if len(row_index) > 0:\n",
Expand All @@ -664,8 +664,8 @@
" print(\"ERROR: No valid exposures found for tile {0:d}, even though EFFTIME_SPEC == {1:f}!\".format(new_tile.tileid, new_tile.efftime_spec))\n",
" bad_index = np.where((update_exposures_table['TILEID'] == new_tile.tileid))[0]\n",
" print(update_exposures_table[['EXPID', 'NIGHT', 'MJD', 'EFFTIME_SPEC']][bad_index])\n",
" bad_tiles.append(new_tile)\n",
"load_tiles, load_exposures"
" # bad_tiles.append(new_tile)\n",
"# load_tiles, load_exposures"
]
},
{
Expand All @@ -684,44 +684,6 @@
" load_frames += db.Frame.convert(update_frames_table, row_index=row_index)"
]
},
{
"cell_type": "code",
"execution_count": 39,
"id": "5104bb5d-2706-4f51-ad72-2c5989a95424",
"metadata": {
"tags": []
},
"outputs": [],
"source": [
"from sqlalchemy.dialects.postgresql import insert"
]
},
{
"cell_type": "code",
"execution_count": 74,
"id": "1f09208d-8c25-4758-9e15-24f69166f76e",
"metadata": {
"tags": []
},
"outputs": [
{
"name": "stdout",
"output_type": "stream",
"text": [
"INSERT INTO daily.version (id, package, version) VALUES (%(id_m0)s, %(package_m0)s, %(version_m0)s), (%(id_m1)s, %(package_m1)s, %(version_m1)s), (%(id_m2)s, %(package_m2)s, %(version_m2)s) ON CONFLICT (id) DO UPDATE SET package = excluded.package, version = excluded.version\n"
]
}
],
"source": [
"stmt = insert(db.Version).values([{'id': 11, 'package': 'foo', 'version': '1.2.6'}, {'id': 10, 'package': 'tiles', 'version': 'main.test6'}, {'id': 12, 'package': 'bar', 'version': '3.4.5'}])\n",
"# stmt = stmt.on_conflict_do_update(index_elements=[db.Version.id], set_=dict(package=getattr(stmt.excluded, 'package'), version=getattr(stmt.excluded, 'version')))\n",
"stmt = stmt.on_conflict_do_update(index_elements=[db.Version.id], set_=dict([(c, getattr(stmt.excluded, c.name)) for c in db.Version.__table__.columns if c.name != 'id']))\n",
"print(stmt)\n",
"# db.dbSession.rollback()\n",
"db.dbSession.execute(stmt)\n",
"db.dbSession.commit()"
]
},
{
"cell_type": "code",
"execution_count": 76,
Expand All @@ -731,65 +693,12 @@
},
"outputs": [],
"source": [
"load_tiles_as_dict = list()\n",
"for t in load_tiles:\n",
" tt = t.__dict__.copy()\n",
" del tt['_sa_instance_state']\n",
" load_tiles_as_dict.append(tt)\n",
"stmt = insert(db.Tile).values(load_tiles_as_dict)\n",
"stmt = stmt.on_conflict_do_update(index_elements=[db.Tile.tileid], set_=dict([(c, getattr(stmt.excluded, c.name)) for c in db.Tile.__table__.columns if c.name != 'tileid']))\n",
"stmt = db.upsert(load_tiles)\n",
"# print(stmt)\n",
"db.dbSession.execute(stmt)\n",
"db.dbSession.commit()"
]
},
{
"cell_type": "code",
"execution_count": 85,
"id": "32163199-fec8-4e03-bf18-9fc21c68e6f8",
"metadata": {
"tags": []
},
"outputs": [
{
"data": {
"text/plain": [
"(<sqlalchemy.orm.attributes.InstrumentedAttribute at 0x7f9bda1a6520>,\n",
" Column('expid', Integer(), table=<exposure>, primary_key=True, nullable=False))"
]
},
"execution_count": 85,
"metadata": {},
"output_type": "execute_result"
}
],
"source": [
"getattr(load_exposures[0].__class__, 'expid'), [c for c in load_exposures[0].__class__.__table__.columns if c.primary_key][0]"
]
},
{
"cell_type": "code",
"execution_count": 88,
"id": "56a310b4-fb2e-4a40-88ed-2b9609379422",
"metadata": {
"tags": []
},
"outputs": [],
"source": [
"def upsert(rows):\n",
" \"\"\"Convert a list of ORM objects into an 'UPSERT' statement.\n",
" \"\"\"\n",
" cls = rows[0].__class__\n",
" pk = [c for c in cls.__table__.columns if c.primary_key][0]\n",
" inserts = list()\n",
" for row in rows:\n",
" rr = row.__dict__.copy()\n",
" del rr['_sa_instance_state']\n",
" inserts.append(rr)\n",
" stmt = insert(cls).values(inserts)\n",
" stmt = stmt.on_conflict_do_update(index_elements=[getattr(cls, pk.name)], set_=dict([(c, getattr(stmt.excluded, c.name)) for c in cls.__table__.columns if c.name != pk.name]))\n",
" return stmt"
]
},
{
"cell_type": "code",
"execution_count": 90,
Expand All @@ -807,8 +716,8 @@
}
],
"source": [
"stmt = upsert(load_exposures)\n",
"print(stmt)\n",
"stmt = db.upsert(load_exposures)\n",
"# print(stmt)\n",
"db.dbSession.execute(stmt)\n",
"db.dbSession.commit()"
]
Expand All @@ -822,8 +731,7 @@
},
"outputs": [],
"source": [
"stmt = upsert(load_frames)\n",
"# db.dbSession.rollback()\n",
"stmt = db.upsert(load_frames)\n",
"# print(stmt)\n",
"db.dbSession.execute(stmt)\n",
"db.dbSession.commit()"
Expand Down
41 changes: 36 additions & 5 deletions py/specprodDB/load.py
Original file line number Diff line number Diff line change
Expand Up @@ -24,14 +24,13 @@

from sqlalchemy import __version__ as sqlalchemy_version
from sqlalchemy import (create_engine, event, ForeignKey, Column, DDL,
BigInteger, Boolean, Integer, String, Float, DateTime,
SmallInteger, bindparam, Numeric, and_, text)
from sqlalchemy.sql import func
from sqlalchemy.exc import IntegrityError, ProgrammingError
BigInteger, Boolean, Integer, String, DateTime,
SmallInteger, Numeric, text)
from sqlalchemy.orm import (DeclarativeBase, declarative_mixin, declared_attr,
scoped_session, sessionmaker, relationship)
from sqlalchemy.schema import CreateSchema, Index
from sqlalchemy.schema import Index
from sqlalchemy.dialects.postgresql import DOUBLE_PRECISION, REAL
from sqlalchemy.dialects.postgresql import insert as pg_insert

from desiutil import __version__ as desiutil_version
from desiutil.iers import freeze_iers
Expand Down Expand Up @@ -1224,6 +1223,33 @@ def convert(cls, data, survey=None, program=None, tileid=None, night=None,
return [cls(**(dict([(col.name, dat) for col, dat in zip(cls.__table__.columns, row)]))) for row in data_rows]


def upsert(rows):
    """Convert a list of ORM objects into an ``INSERT ... ON CONFLICT`` statement.

    Parameters
    ----------
    rows : :class:`list`
        A list of ORM objects. All items should be the same type.

    Returns
    -------
    :class:`~sqlalchemy.dialects.postgresql.Insert`
        A specialized INSERT statement ready for execution.

    Raises
    ------
    IndexError
        If `rows` is empty.
    """
    cls = rows[0].__class__
    # NOTE(review): assumes a single-column primary key; the first primary-key
    # column found is used as the conflict target.
    pk = [c for c in cls.__table__.columns if c.primary_key][0]
    # Strip the SQLAlchemy session-bookkeeping attribute so that only real
    # column values are passed to the INSERT.
    inserts = [{k: v for k, v in row.__dict__.items() if k != '_sa_instance_state'}
               for row in rows]
    stmt = pg_insert(cls).values(inserts)
    # On primary-key conflict, overwrite every non-PK column with the
    # incoming ("excluded") value, i.e. a full-row update.
    stmt = stmt.on_conflict_do_update(index_elements=[getattr(cls, pk.name)],
                                      set_={c: getattr(stmt.excluded, c.name)
                                            for c in cls.__table__.columns
                                            if c.name != pk.name})
    return stmt


def deduplicate_targetid(data):
"""Find targetphot rows that are not already loaded into the Photometry
table *and* resolve any duplicate TARGETID.
Expand Down Expand Up @@ -1739,4 +1765,9 @@ def main():
log.info("Finished loading %s.", tn)
if options.load == 'fiberassign':
log.info("Consider running VACUUM FULL VERBOSE ANALYZE at this point.")
#
# Clean up.
#
dbSession.close()
engine.dispose()
return 0
61 changes: 37 additions & 24 deletions py/specprodDB/tile.py
Original file line number Diff line number Diff line change
Expand Up @@ -322,7 +322,9 @@ def load_redshift(tile, spgrp='cumulative'):
tile.tileid, firstnight,
row_index=row_index)
if len(load_ztile) > 0:
db.dbSession.add_all(load_ztile)
statement = db.upsert(load_ztile)
db.dbSession.execute(statement)
# db.dbSession.add_all(load_ztile)
db.dbSession.commit()
db.log.info("Loaded %d rows of Ztile data.", len(load_ztile))
else:
Expand Down Expand Up @@ -436,6 +438,8 @@ def get_options(description="Load data for one tile into a specprod database."):
help='Update primary redshift values and indexes for all tiles.')
prsr.add_argument('-t', '--tiles-file', action='store', dest='tiles_file', metavar='FILE',
help='Override the top-level tiles file associated with a specprod.')
prsr.add_argument('-u', '--update', action='store_true', dest='update',
help='Specify that this is an update to an already-loaded tile.')
prsr.add_argument('tile', metavar='TILEID', type=int, help='Load TILEID.')
options = prsr.parse_args()
return options
Expand Down Expand Up @@ -509,10 +513,10 @@ def main():
# Find the tile in the top-level tiles file.
#
if options.tiles_file is None:
tiles_file = findfile('tiles', readonly=True).replace('.fits', '.csv')
tiles_file = findfile('tiles', readonly=True)
else:
tiles_file = options.tiles_file
tiles_table = Table.read(tiles_file, format='ascii.csv')
tiles_table = Table.read(tiles_file, format='fits', hdu='TILES')
row_index = np.where(tiles_table['TILEID'] == options.tile)[0]
if len(row_index) == 1:
candidate_tiles = db.Tile.convert(tiles_table, row_index=row_index)
Expand Down Expand Up @@ -546,7 +550,9 @@ def main():
assert len(row_index) > 0
load_frames += db.Frame.convert(frames_table, row_index=row_index)
try:
db.dbSession.add_all(candidate_tiles)
statement = db.upsert(candidate_tiles)
db.dbSession.execute(statement)
# db.dbSession.add_all(candidate_tiles)
db.dbSession.commit()
except IntegrityError as exc:
#
Expand All @@ -557,41 +563,46 @@ def main():
db.log.critical("Message was: %s", exc.args[0])
db.dbSession.rollback()
return 1
new_tile = candidate_tiles[0]
try:
db.dbSession.add_all(load_exposures)
statement = db.upsert(load_exposures)
db.dbSession.execute(statement)
# db.dbSession.add_all(load_exposures)
db.dbSession.commit()
except IntegrityError as exc:
db.log.critical("Exposures for tile %d cannot be loaded!", candidate_tiles[0].tileid)
db.log.critical("Exposures for tile %d cannot be loaded!", new_tile.tileid)
db.log.critical("Message was: %s", exc.args[0])
db.dbSession.rollback()
db.dbSession.delete(candidate_tiles[0])
db.dbSession.delete(new_tile)
db.dbSession.commit()
return 1
db.dbSession.add_all(load_frames)
statement = db.upsert(load_frames)
db.dbSession.execute(statement)
# db.dbSession.add_all(load_frames)
db.dbSession.commit()
#
# Load photometry.
#
new_tile = candidate_tiles[0]
potential_targets_table = potential_targets(new_tile.tileid)
potential_cat = potential_photometry(new_tile, potential_targets_table)
potential_targetphot = targetphot(potential_cat)
potential_tractorphot = tractorphot(potential_cat)
loaded_photometry = load_photometry(potential_tractorphot)
loaded_targetphot = load_targetphot(potential_targetphot, loaded_photometry)
# Load photometry. If this is an update, these should already be loaded.
#
# Load targeting table.
#
loaded_target = load_target(new_tile, potential_targetphot)
if not options.update:
potential_targets_table = potential_targets(new_tile.tileid)
potential_cat = potential_photometry(new_tile, potential_targets_table)
potential_targetphot = targetphot(potential_cat)
potential_tractorphot = tractorphot(potential_cat)
loaded_photometry = load_photometry(potential_tractorphot)
loaded_targetphot = load_targetphot(potential_targetphot, loaded_photometry)
#
# Load targeting table.
#
loaded_target = load_target(new_tile, potential_targetphot)
#
# Load fiberassign and potential.
#
loaded_fiberassign, loaded_potential = load_fiberassign(new_tile)
#
# Load tile/cumulative redshifts.
#
loaded_ztile = load_redshift(new_tile)
#
# Load fiberassign and potential.
#
loaded_fiberassign, loaded_potential = load_fiberassign(new_tile)
#
# Update global values, if requested.
#
if options.primary:
Expand All @@ -600,4 +611,6 @@ def main():
#
# Clean up.
#
db.dbSession.close()
db.engine.dispose()
return 0

0 comments on commit b61e658

Please sign in to comment.