Skip to content

Commit

Permalink
Merge pull request #277 from compute-tooling/cache-model-params
Browse files Browse the repository at this point in the history
Cache model parameters
  • Loading branch information
hdoupe authored Apr 7, 2020
2 parents 0e952c4 + 0781473 commit 48f4176
Show file tree
Hide file tree
Showing 18 changed files with 479 additions and 260 deletions.
14 changes: 10 additions & 4 deletions webapp/apps/comp/asyncsubmit.py
Original file line number Diff line number Diff line change
Expand Up @@ -8,6 +8,8 @@
from rest_framework import status
from rest_framework.response import Response

import paramtools as pt

from webapp.apps.users.models import Project

from webapp.apps.comp import actions
Expand Down Expand Up @@ -40,7 +42,7 @@ def __init__(
self.ioutils = ioutils
self.compute = compute
self.badpost = None
self.meta_parameters = ioutils.displayer.parsed_meta_parameters()
self.meta_parameters = ioutils.model_parameters.meta_parameters_parser()
self.sim = sim

def submit(self):
Expand All @@ -63,17 +65,20 @@ def submit(self):
parent_sim = None

try:
self.valid_meta_params = self.meta_parameters.validate(meta_parameters)
self.meta_parameters.adjust(meta_parameters)
self.valid_meta_params = self.meta_parameters.specification(
meta_data=False, serializable=True
)
errors = None
except ValidationError as ve:
except pt.ValidationError as ve:
errors = str(ve)

if errors:
raise BadPostException(errors)

parser = self.ioutils.Parser(
self.project,
self.ioutils.displayer,
self.ioutils.model_parameters,
adjustment,
compute=self.compute,
**self.valid_meta_params,
Expand All @@ -88,6 +93,7 @@ def submit(self):
job_id=result["job_id"],
status="PENDING",
parent_sim=self.sim.parent_sim or parent_sim,
model_config=self.ioutils.model_parameters.config,
)
# case where parent sim exists and has not yet been assigned
if not self.sim.parent_sim and parent_sim:
Expand Down
42 changes: 0 additions & 42 deletions webapp/apps/comp/displayer.py

This file was deleted.

6 changes: 3 additions & 3 deletions webapp/apps/comp/ioutils.py
Original file line number Diff line number Diff line change
@@ -1,15 +1,15 @@
from typing import NamedTuple, Type
from webapp.apps.comp.displayer import Displayer
from webapp.apps.comp.model_parameters import ModelParameters
from webapp.apps.comp.parser import Parser


class IOClasses(NamedTuple):
displayer: Displayer
model_parameters: ModelParameters
Parser: Type[Parser]


def get_ioutils(project, **kwargs):
return IOClasses(
displayer=kwargs.get("Displayer", Displayer)(project),
model_parameters=kwargs.get("ModelParameters", ModelParameters)(project),
Parser=kwargs.get("Parser", Parser),
)
95 changes: 0 additions & 95 deletions webapp/apps/comp/meta_parameters.py

This file was deleted.

78 changes: 78 additions & 0 deletions webapp/apps/comp/migrations/0027_auto_20200406_1547.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,78 @@
# Generated by Django 3.0.3 on 2020-04-06 20:47

import django.contrib.postgres.fields.jsonb
from django.db import migrations, models
import django.db.models.deletion
import django.utils.timezone
import webapp.apps.comp.models


class Migration(migrations.Migration):

dependencies = [
("users", "0010_auto_20200319_0854"),
("comp", "0026_auto_20200221_1228"),
]

operations = [
migrations.CreateModel(
name="ModelConfig",
fields=[
(
"id",
models.AutoField(
auto_created=True,
primary_key=True,
serialize=False,
verbose_name="ID",
),
),
(
"inputs_version",
models.CharField(choices=[("v1", "Version 1")], max_length=10),
),
(
"model_version",
models.CharField(
blank=True, default=None, max_length=100, null=True
),
),
(
"creation_date",
models.DateTimeField(default=django.utils.timezone.now),
),
(
"meta_parameters_values",
django.contrib.postgres.fields.jsonb.JSONField(null=True),
),
("meta_parameters", webapp.apps.comp.models.JSONField(default=dict)),
("model_parameters", webapp.apps.comp.models.JSONField(default=dict)),
(
"project",
models.ForeignKey(
null=True,
on_delete=django.db.models.deletion.SET_NULL,
related_name="model_configs",
to="users.Project",
),
),
],
),
migrations.AddField(
model_name="inputs",
name="model_config",
field=models.ForeignKey(
null=True,
on_delete=django.db.models.deletion.SET_NULL,
related_name="inputs_instances",
to="comp.ModelConfig",
),
),
migrations.AddConstraint(
model_name="modelconfig",
constraint=models.UniqueConstraint(
fields=("project", "model_version", "meta_parameters_values"),
name="unique_model_config",
),
),
]
97 changes: 97 additions & 0 deletions webapp/apps/comp/model_parameters.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,97 @@
import paramtools as pt

from webapp.apps.comp.models import ModelConfig
from webapp.apps.comp.compute import SyncCompute, JobFailError
from webapp.apps.comp import actions
from webapp.apps.comp.exceptions import AppError


import os
import json

INPUTS = os.path.join(os.path.abspath(os.path.dirname(__file__)), "inputs.json")


def pt_factory(classname, defaults):
return type(classname, (pt.Parameters,), {"defaults": defaults})


class ModelParameters:
"""
Handles logic for getting cached model parameters and updating the cache.
"""

def __init__(self, project: "Project", compute: SyncCompute = None):
self.project = project
self.compute = compute or SyncCompute()
self.config = None

def defaults(self, init_meta_parameters=None):
# get Parameters class for meta parameters and adjust its values.
meta_param_parser = self.meta_parameters_parser()
meta_param_parser.adjust(init_meta_parameters or {})
meta_parameters = meta_param_parser.dump()
return {
"model_parameters": self.model_parameters_parser(
meta_param_parser.specification(meta_data=False, serializable=True)
),
"meta_parameters": meta_parameters,
}

def meta_parameters_parser(self):
res = self.get_inputs()
return pt_factory("MetaParametersParser", res["meta_parameters"])()

def model_parameters_parser(self, meta_parameters_values=None):
res = self.get_inputs(meta_parameters_values)
# TODO: just return defaults or return the parsers, too?
# model_parameters_parser = {}
# for sect, defaults in res["model_parameters"]:
# model_parameters_parser[sect] = type(
# "Parser", (pt.Parameters), {"defaults": defaults},
# )()
# return model_parameters_parser
return res["model_parameters"]

def get_inputs(self, meta_parameters_values=None):
"""
Get cached version of inputs or retrieve new version.
"""
meta_parameters_values = meta_parameters_values or {}

try:
config = ModelConfig.objects.get(
project=self.project,
model_version=self.project.version,
meta_parameters_values=meta_parameters_values,
)
except ModelConfig.DoesNotExist:
success, result = self.compute.submit_job(
{"meta_param_dict": meta_parameters_values or {}},
self.project.worker_ext(action=actions.INPUTS),
)
if not success:
raise AppError(meta_parameters_values, result["traceback"])

# clean up meta parameters before saving them.
if meta_parameters_values:
mp = pt_factory("MP", result["meta_parameters"])()
mp.adjust(meta_parameters_values)
save_vals = mp.specification(meta_data=False, serializable=True)
else:
save_vals = {}

config = ModelConfig.objects.create(
project=self.project,
model_version=self.project.version,
meta_parameters_values=save_vals,
meta_parameters=result["meta_parameters"],
model_parameters=result["model_parameters"],
inputs_version="v1",
)

self.config = config
return {
"meta_parameters": config.meta_parameters,
"model_parameters": config.model_parameters,
}
Loading

0 comments on commit 48f4176

Please sign in to comment.