-
Notifications
You must be signed in to change notification settings - Fork 5
Commit
This commit does not belong to any branch on this repository, and may belong to a fork outside of the repository.
Merge pull request #277 from compute-tooling/cache-model-params
Cache model parameters
- Loading branch information
Showing
18 changed files
with
479 additions
and
260 deletions.
There are no files selected for viewing
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file was deleted.
Oops, something went wrong.
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -1,15 +1,15 @@ | ||
from typing import NamedTuple, Type | ||
from webapp.apps.comp.displayer import Displayer | ||
from webapp.apps.comp.model_parameters import ModelParameters | ||
from webapp.apps.comp.parser import Parser | ||
|
||
|
||
class IOClasses(NamedTuple): | ||
displayer: Displayer | ||
model_parameters: ModelParameters | ||
Parser: Type[Parser] | ||
|
||
|
||
def get_ioutils(project, **kwargs): | ||
return IOClasses( | ||
displayer=kwargs.get("Displayer", Displayer)(project), | ||
model_parameters=kwargs.get("ModelParameters", ModelParameters)(project), | ||
Parser=kwargs.get("Parser", Parser), | ||
) |
This file was deleted.
Oops, something went wrong.
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,78 @@ | ||
# Generated by Django 3.0.3 on 2020-04-06 20:47 | ||
|
||
import django.contrib.postgres.fields.jsonb | ||
from django.db import migrations, models | ||
import django.db.models.deletion | ||
import django.utils.timezone | ||
import webapp.apps.comp.models | ||
|
||
|
||
class Migration(migrations.Migration): | ||
|
||
dependencies = [ | ||
("users", "0010_auto_20200319_0854"), | ||
("comp", "0026_auto_20200221_1228"), | ||
] | ||
|
||
operations = [ | ||
migrations.CreateModel( | ||
name="ModelConfig", | ||
fields=[ | ||
( | ||
"id", | ||
models.AutoField( | ||
auto_created=True, | ||
primary_key=True, | ||
serialize=False, | ||
verbose_name="ID", | ||
), | ||
), | ||
( | ||
"inputs_version", | ||
models.CharField(choices=[("v1", "Version 1")], max_length=10), | ||
), | ||
( | ||
"model_version", | ||
models.CharField( | ||
blank=True, default=None, max_length=100, null=True | ||
), | ||
), | ||
( | ||
"creation_date", | ||
models.DateTimeField(default=django.utils.timezone.now), | ||
), | ||
( | ||
"meta_parameters_values", | ||
django.contrib.postgres.fields.jsonb.JSONField(null=True), | ||
), | ||
("meta_parameters", webapp.apps.comp.models.JSONField(default=dict)), | ||
("model_parameters", webapp.apps.comp.models.JSONField(default=dict)), | ||
( | ||
"project", | ||
models.ForeignKey( | ||
null=True, | ||
on_delete=django.db.models.deletion.SET_NULL, | ||
related_name="model_configs", | ||
to="users.Project", | ||
), | ||
), | ||
], | ||
), | ||
migrations.AddField( | ||
model_name="inputs", | ||
name="model_config", | ||
field=models.ForeignKey( | ||
null=True, | ||
on_delete=django.db.models.deletion.SET_NULL, | ||
related_name="inputs_instances", | ||
to="comp.ModelConfig", | ||
), | ||
), | ||
migrations.AddConstraint( | ||
model_name="modelconfig", | ||
constraint=models.UniqueConstraint( | ||
fields=("project", "model_version", "meta_parameters_values"), | ||
name="unique_model_config", | ||
), | ||
), | ||
] |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,97 @@ | ||
import paramtools as pt | ||
|
||
from webapp.apps.comp.models import ModelConfig | ||
from webapp.apps.comp.compute import SyncCompute, JobFailError | ||
from webapp.apps.comp import actions | ||
from webapp.apps.comp.exceptions import AppError | ||
|
||
|
||
import os | ||
import json | ||
|
||
INPUTS = os.path.join(os.path.abspath(os.path.dirname(__file__)), "inputs.json") | ||
|
||
|
||
def pt_factory(classname, defaults): | ||
return type(classname, (pt.Parameters,), {"defaults": defaults}) | ||
|
||
|
||
class ModelParameters: | ||
""" | ||
Handles logic for getting cached model parameters and updating the cache. | ||
""" | ||
|
||
def __init__(self, project: "Project", compute: SyncCompute = None): | ||
self.project = project | ||
self.compute = compute or SyncCompute() | ||
self.config = None | ||
|
||
def defaults(self, init_meta_parameters=None): | ||
# get Parameters class for meta parameters and adjust its values. | ||
meta_param_parser = self.meta_parameters_parser() | ||
meta_param_parser.adjust(init_meta_parameters or {}) | ||
meta_parameters = meta_param_parser.dump() | ||
return { | ||
"model_parameters": self.model_parameters_parser( | ||
meta_param_parser.specification(meta_data=False, serializable=True) | ||
), | ||
"meta_parameters": meta_parameters, | ||
} | ||
|
||
def meta_parameters_parser(self): | ||
res = self.get_inputs() | ||
return pt_factory("MetaParametersParser", res["meta_parameters"])() | ||
|
||
def model_parameters_parser(self, meta_parameters_values=None): | ||
res = self.get_inputs(meta_parameters_values) | ||
# TODO: just return defaults or return the parsers, too? | ||
# model_parameters_parser = {} | ||
# for sect, defaults in res["model_parameters"]: | ||
# model_parameters_parser[sect] = type( | ||
# "Parser", (pt.Parameters), {"defaults": defaults}, | ||
# )() | ||
# return model_parameters_parser | ||
return res["model_parameters"] | ||
|
||
def get_inputs(self, meta_parameters_values=None): | ||
""" | ||
Get cached version of inputs or retrieve new version. | ||
""" | ||
meta_parameters_values = meta_parameters_values or {} | ||
|
||
try: | ||
config = ModelConfig.objects.get( | ||
project=self.project, | ||
model_version=self.project.version, | ||
meta_parameters_values=meta_parameters_values, | ||
) | ||
except ModelConfig.DoesNotExist: | ||
success, result = self.compute.submit_job( | ||
{"meta_param_dict": meta_parameters_values or {}}, | ||
self.project.worker_ext(action=actions.INPUTS), | ||
) | ||
if not success: | ||
raise AppError(meta_parameters_values, result["traceback"]) | ||
|
||
# clean up meta parameters before saving them. | ||
if meta_parameters_values: | ||
mp = pt_factory("MP", result["meta_parameters"])() | ||
mp.adjust(meta_parameters_values) | ||
save_vals = mp.specification(meta_data=False, serializable=True) | ||
else: | ||
save_vals = {} | ||
|
||
config = ModelConfig.objects.create( | ||
project=self.project, | ||
model_version=self.project.version, | ||
meta_parameters_values=save_vals, | ||
meta_parameters=result["meta_parameters"], | ||
model_parameters=result["model_parameters"], | ||
inputs_version="v1", | ||
) | ||
|
||
self.config = config | ||
return { | ||
"meta_parameters": config.meta_parameters, | ||
"model_parameters": config.model_parameters, | ||
} |
Oops, something went wrong.