Skip to content

Commit

Permalink
Replace get_simulation_dataset_ids_by_plan_id with list_simulation_da…
Browse files Browse the repository at this point in the history
…tasets_by_plan_id
  • Loading branch information
David Legg committed Nov 1, 2023
1 parent ac1b33e commit d5da3d1
Show file tree
Hide file tree
Showing 6 changed files with 121 additions and 15 deletions.
48 changes: 41 additions & 7 deletions src/aerie_cli/aerie_client.py
Original file line number Diff line number Diff line change
Expand Up @@ -4,6 +4,7 @@
from typing import Dict
from typing import List
from copy import deepcopy
from warnings import warn

import arrow

Expand All @@ -12,13 +13,15 @@
from .schemas.api import ApiMissionModelCreate
from .schemas.api import ApiMissionModelRead
from .schemas.api import ApiResourceSampleResults
from .schemas.api import ApiSimulationDatasetRead
from .schemas.client import Activity
from .schemas.client import ActivityPlanCreate
from .schemas.client import ActivityPlanRead
from .schemas.client import CommandDictionaryInfo
from .schemas.client import ExpansionRun
from .schemas.client import ExpansionRule
from .schemas.client import ExpansionSet
from .schemas.client import SimulationDataset
from .schemas.client import ResourceType
from .utils.serialization import postgres_interval_to_microseconds
from .aerie_host import AerieHost
Expand Down Expand Up @@ -355,7 +358,7 @@ def exec_sim_query():
return sim_dataset_id

def get_resource_timelines(self, plan_id: int):
samples = self.get_resource_samples(self.get_simulation_dataset_ids_by_plan_id(plan_id)[0])
samples = self.get_resource_samples(self.list_simulation_datasets_by_plan_id(plan_id)[0].id)
api_resource_timeline = ApiResourceSampleResults.from_dict(samples)
return api_resource_timeline

Expand Down Expand Up @@ -973,27 +976,58 @@ def get_rules_by_type(self) -> Dict[str, List[ExpansionRule]]:
return rules_by_type

def get_simulation_dataset_ids_by_plan_id(self, plan_id: int) -> List[int]:
"""Get the IDs of the simulation datasets generated from a given plan
warn("get_simulation_dataset_ids_by_plan_id is deprecated. "
"Use list_simulation_datasets_by_plan_id instead",
DeprecationWarning,
stacklevel=2)
return [s.id for s in self.list_simulation_datasets_by_plan_id(plan_id)]

# TODO: Change output type to sim dataset
def list_simulation_datasets_by_plan_id(self, plan_id: int) -> List[SimulationDataset]:
"""Get metadata for the simulation datasets generated from a given plan
Args:
plan_id (int): ID of parent plan
Returns:
List[int]: IDs of simulation datasets in descending order
List[SimulationDataset]: Simulation datasets in descending order by ID
"""

# Since GQL will group results by simulation, we have to sort client-side
get_simulation_dataset_query = """
query GetSimulationDatasetId($plan_id: Int!) {
simulation(where: {plan_id: {_eq: $plan_id}}, order_by: { id: desc }, limit: 1) {
simulation_datasets(order_by: { id: desc }) {
id
plan_by_pk(id: $plan_id) {
simulations {
simulation_datasets {
id
simulation_id
dataset_id
offset_from_plan_start
plan_revision
model_revision
simulation_template_revision
simulation_revision
dataset_revision
arguments
simulation_start_time
simulation_end_time
status
reason
canceled
requested_by
requested_at
}
}
}
}
"""
data = self.aerie_host.post_to_graphql(
get_simulation_dataset_query, plan_id=plan_id)
return [d["id"] for d in data[0]["simulation_datasets"]]
result = [SimulationDataset(**d)
for sim in data["simulations"]
for d in sim["simulation_datasets"]]
result.sort(key=lambda s: s.id, reverse=True)
return result

def expand_simulation(
self, simulation_dataset_id: int, expansion_set_id: int
Expand Down
8 changes: 4 additions & 4 deletions src/aerie_cli/commands/expansion.py
Original file line number Diff line number Diff line change
Expand Up @@ -72,8 +72,8 @@ def list_expansion_runs(
client = CommandContext.get_client()

if simulation_dataset_id is None:
simulation_datasets = client.get_simulation_dataset_ids_by_plan_id(
plan_id)
simulation_datasets = [
d.id for d in client.list_simulation_datasets_by_plan_id(plan_id)]
table_caption = f'All runs for Plan ID {plan_id}'
else:
simulation_datasets = [simulation_dataset_id]
Expand Down Expand Up @@ -132,8 +132,8 @@ def list_sequences(
client = CommandContext.get_client()

if simulation_dataset_id is None:
simulation_datasets = client.get_simulation_dataset_ids_by_plan_id(
plan_id)
simulation_datasets = [
d.id for d in client.list_simulation_datasets_by_plan_id(plan_id)]
table_caption = f'All sequences for Plan ID {plan_id}'
else:
simulation_datasets = [simulation_dataset_id]
Expand Down
4 changes: 2 additions & 2 deletions src/aerie_cli/commands/plans.py
Original file line number Diff line number Diff line change
Expand Up @@ -242,9 +242,9 @@ def list():
table.add_column("Latest Sim. Dataset ID", no_wrap=True)
table.add_column("Model ID", no_wrap=True)
for activity_plan in resp:
sim_ids = client.get_simulation_dataset_ids_by_plan_id(activity_plan.id)
sim_ids = client.list_simulation_datasets_by_plan_id(activity_plan.id)
if len(sim_ids):
simulation_dataset_id = str(max(sim_ids))
simulation_dataset_id = str(sim_ids[0].id)
else:
simulation_dataset_id = ''

Expand Down
29 changes: 29 additions & 0 deletions src/aerie_cli/schemas/api.py
Original file line number Diff line number Diff line change
Expand Up @@ -191,3 +191,32 @@ class ApiMissionModelCreate(ApiSerialize):
@define
class ApiMissionModelRead(ApiMissionModelCreate):
id: int


@define
class ApiSimulationDatasetRead(ApiSerialize):
id: int
simulation_id: int
dataset_id: int
offset_from_plan_start: timedelta = field(
converter=convert_to_time_delta
)
plan_revision: int
model_revision: int
simulation_template_revision: int
simulation_revision: int
dataset_revision: int
arguments: Dict[str, Any]
simulation_start_time: Arrow = field(
converter=arrow.get
)
simulation_end_time: Arrow = field(
converter=arrow.get
)
status: str
reason: str
canceled: bool
requested_by: str
requested_at: Arrow = field(
converter=arrow.get
)
43 changes: 43 additions & 0 deletions src/aerie_cli/schemas/client.py
Original file line number Diff line number Diff line change
Expand Up @@ -26,6 +26,7 @@
from aerie_cli.schemas.api import ApiResourceSampleResults
from aerie_cli.schemas.api import ApiSimulatedResourceSample
from aerie_cli.schemas.api import ApiSimulationResults
from aerie_cli.schemas.api import ApiSimulationDatasetRead
from aerie_cli.schemas.api import ActivityBase

def parse_timedelta_str_converter(t) -> timedelta:
Expand Down Expand Up @@ -370,3 +371,45 @@ class ExpansionRule(ClientSerialize):
class ResourceType(ClientSerialize):
name: str
schema: Dict

@define
class SimulationDataset(ClientSerialize):
id: int
simulation_id: int
dataset_id: int
offset_from_plan_start: timedelta
plan_revision: int
model_revision: int
simulation_template_revision: int
simulation_revision: int
dataset_revision: int
arguments: Dict[str, Any]
simulation_start_time: Arrow
simulation_end_time: Arrow
status: str
reason: str
canceled: bool
requested_by: str
requested_at: Arrow

@classmethod
def from_api_read(cls, api_sim_dataset: ApiSimulationDatasetRead) -> "SimulationDataset":
return SimulationDataset(
id=api_sim_dataset.id,
simulation_id=api_sim_dataset.simulation_id,
dataset_id=api_sim_dataset.dataset_id,
offset_from_plan_start=api_sim_dataset.offset_from_plan_start,
plan_revision=api_sim_dataset.plan_revision,
model_revision=api_sim_dataset.model_revision,
simulation_template_revision=api_sim_dataset.simulation_template_revision,
simulation_revision=api_sim_dataset.simulation_revision,
dataset_revision=api_sim_dataset.dataset_revision,
arguments=api_sim_dataset.arguments,
simulation_start_time=api_sim_dataset.simulation_start_time,
simulation_end_time=api_sim_dataset.simulation_end_time,
status=api_sim_dataset.status,
reason=api_sim_dataset.reason,
canceled=api_sim_dataset.canceled,
requested_by=api_sim_dataset.requested_by,
requested_at=api_sim_dataset.requested_at
)
4 changes: 2 additions & 2 deletions tests/integration_tests/test_plans.py
Original file line number Diff line number Diff line change
Expand Up @@ -186,9 +186,9 @@ def test_delete_collaborators():

def test_plan_simulate():
result = cli_plan_simulate()
sim_ids = client.get_simulation_dataset_ids_by_plan_id(plan_id)
sim_ids = client.list_simulation_datasets_by_plan_id(plan_id)
global sim_id
sim_id = sim_ids[-1]
sim_id = sim_ids[0].id
assert result.exit_code == 0,\
f"{result.stdout}"\
f"{result.stderr}"
Expand Down

0 comments on commit d5da3d1

Please sign in to comment.