OO-ify the code generation code #168

Merged — 16 commits, Nov 20, 2024
2 changes: 1 addition & 1 deletion .flake8
@@ -9,4 +9,4 @@ extend-ignore = E203,E501,E701
 
 per-file-ignores =
     # Ignore undefined names in templates.
-    */templates/no-tests/*.py:F821,F401,E302
+    */code_generators/no-tests/*.py:F821,F401,E302
2 changes: 1 addition & 1 deletion .pytest.ini
@@ -4,7 +4,7 @@
 filterwarnings =
     error
 
-addopts = --doctest-glob '*.md' --doctest-modules --ignore dp_wizard/utils/templates/no-tests --ignore dp_wizard/tests/fixtures/ --tracing=retain-on-failure
+addopts = --doctest-glob '*.md' --doctest-modules --ignore dp_wizard/utils/code_generators/no-tests --ignore dp_wizard/tests/fixtures/ --tracing=retain-on-failure
 
 # If an xfail starts passing unexpectedly, that should count as a failure:
 xfail_strict=true
2 changes: 1 addition & 1 deletion dp_wizard/app/analysis_panel.py
@@ -6,7 +6,7 @@
 from dp_wizard.app.components.column_module import column_ui, column_server
 from dp_wizard.utils.csv_helper import read_csv_ids_labels, read_csv_ids_names
 from dp_wizard.app.components.outputs import output_code_sample, demo_tooltip
-from dp_wizard.utils.templates import make_privacy_loss_block
+from dp_wizard.utils.code_generators import make_privacy_loss_block
 from dp_wizard.app.components.column_module import col_widths


2 changes: 1 addition & 1 deletion dp_wizard/app/components/column_module.py
@@ -4,7 +4,7 @@
 
 from dp_wizard.utils.dp_helper import make_confidence_accuracy_histogram
 from dp_wizard.utils.shared import plot_histogram
-from dp_wizard.utils.templates import make_column_config_block
+from dp_wizard.utils.code_generators import make_column_config_block
 from dp_wizard.app.components.outputs import output_code_sample, demo_tooltip


2 changes: 1 addition & 1 deletion dp_wizard/app/dataset_panel.py
@@ -4,7 +4,7 @@
 
 from dp_wizard.utils.argparse_helpers import get_cli_info
 from dp_wizard.app.components.outputs import output_code_sample, demo_tooltip
-from dp_wizard.utils.templates import make_privacy_unit_block
+from dp_wizard.utils.code_generators import make_privacy_unit_block


def dataset_ui():
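The three import renames above are mechanical: the snippet helpers keep their signatures and simply move from `dp_wizard.utils.templates` to `dp_wizard.utils.code_generators`, since they build small UI previews without needing a full analysis plan. A sketch of the call sites, with illustrative (hypothetical) values:

```python
from dp_wizard.utils.code_generators import (
    make_privacy_unit_block,
    make_privacy_loss_block,
    make_column_config_block,
)

# Each helper returns a code snippet (a str) that the UI shows as a preview.
unit_block = make_privacy_unit_block(contributions=1)
loss_block = make_privacy_loss_block(epsilon=1.0)
column_block = make_column_config_block(
    name="HW GRADE", lower_bound=0, upper_bound=100, bin_count=10
)
```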
75 changes: 20 additions & 55 deletions dp_wizard/app/results_panel.py
@@ -1,8 +1,11 @@
-from json import dumps
-
 from shiny import ui, render, reactive
 
-from dp_wizard.utils.templates import make_notebook_py, make_script_py
+from dp_wizard.utils.code_generators import (
+    NotebookGenerator,
+    ScriptGenerator,
+    AnalysisPlan,
+    AnalysisPlanColumn,
+)
 from dp_wizard.utils.converters import convert_py_to_nb
 
 
@@ -35,77 +38,39 @@ def results_server(
     epsilon,
 ):  # pragma: no cover
     @reactive.calc
-    def analysis_dict():
+    def analysis_plan() -> AnalysisPlan:
         # weights().keys() will reflect the desired columns:
         # The others retain inactive columns, so user
         # inputs aren't lost when toggling checkboxes.
         columns = {
-            col: {
-                "lower_bound": lower_bounds()[col],
-                "upper_bound": upper_bounds()[col],
-                "bin_count": int(bin_counts()[col]),
-                # TODO: Floats should work for weight, but they don't:
-                # https://github.com/opendp/opendp/issues/2140
-                "weight": int(weights()[col]),
-            }
+            col: AnalysisPlanColumn(
+                lower_bound=lower_bounds()[col],
+                upper_bound=upper_bounds()[col],
+                bin_count=int(bin_counts()[col]),
+                weight=int(weights()[col]),
+            )
             for col in weights().keys()
         }
-        return {
-            "csv_path": csv_path(),
-            "contributions": contributions(),
-            "epsilon": epsilon(),
-            "columns": columns,
-        }
-
-    @reactive.calc
-    def analysis_json():
-        return dumps(
-            analysis_dict(),
-            indent=2,
+        return AnalysisPlan(
+            csv_path=csv_path(),
+            contributions=contributions(),
+            epsilon=epsilon(),
+            columns=columns,
         )
 
-    @render.text
-    def analysis_json_text():
-        return analysis_json()
-
-    @reactive.calc
-    def analysis_python():
-        analysis = analysis_dict()
-        return make_notebook_py(
-            csv_path=analysis["csv_path"],
-            contributions=analysis["contributions"],
-            epsilon=analysis["epsilon"],
-            columns=analysis["columns"],
-        )
-
-    @render.text
-    def analysis_python_text():
-        return analysis_python()
-
     @render.download(
         filename="dp-wizard-script.py",
         media_type="text/x-python",
     )
     async def download_script():
-        analysis = analysis_dict()
-        script_py = make_script_py(
-            contributions=analysis["contributions"],
-            epsilon=analysis["epsilon"],
-            columns=analysis["columns"],
-        )
+        script_py = ScriptGenerator(analysis_plan()).make_py()
         yield script_py
 
     @render.download(
         filename="dp-wizard-notebook.ipynb",
         media_type="application/x-ipynb+json",
    )
     async def download_notebook():
-        analysis = analysis_dict()
-        notebook_py = make_notebook_py(
-            csv_path=analysis["csv_path"],
-            contributions=analysis["contributions"],
-            epsilon=analysis["epsilon"],
-            columns=analysis["columns"],
-        )
+        notebook_py = NotebookGenerator(analysis_plan()).make_py()
         notebook_nb = convert_py_to_nb(notebook_py, execute=True)
         yield notebook_nb
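Taken together, results_server now builds a single AnalysisPlan and hands it to whichever generator matches the requested download. A minimal sketch of that flow outside of Shiny, using hypothetical path and column values:

```python
from dp_wizard.utils.code_generators import (
    AnalysisPlan,
    AnalysisPlanColumn,
    NotebookGenerator,
    ScriptGenerator,
)

# Hypothetical values standing in for the panel's reactive inputs:
plan = AnalysisPlan(
    csv_path="grades.csv",
    contributions=1,
    epsilon=1.0,
    columns={
        "HW GRADE": AnalysisPlanColumn(
            lower_bound=0, upper_bound=100, bin_count=10, weight=1
        ),
    },
)

# The same plan drives both outputs:
script_py = ScriptGenerator(plan).make_py()
notebook_py = NotebookGenerator(plan).make_py()  # converted to .ipynb via convert_py_to_nb
```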
204 changes: 204 additions & 0 deletions dp_wizard/utils/code_generators/__init__.py
@@ -0,0 +1,204 @@
from typing import NamedTuple
from abc import ABC, abstractmethod
from pathlib import Path
import re
from dp_wizard.utils.csv_helper import name_to_identifier
from dp_wizard.utils.code_generators._template import Template


class AnalysisPlanColumn(NamedTuple):
lower_bound: float
upper_bound: float
bin_count: int
weight: int


class AnalysisPlan(NamedTuple):
csv_path: str
contributions: int
epsilon: float
columns: dict[str, AnalysisPlanColumn]


class _CodeGenerator(ABC):
def __init__(self, analysis_plan):
self.csv_path = analysis_plan.csv_path
self.contributions = analysis_plan.contributions
self.epsilon = analysis_plan.epsilon
self.columns = analysis_plan.columns

@abstractmethod
def _make_context(self): ... # pragma: no cover

def make_py(self):
return str(
Template(self.root_template).fill_blocks(
IMPORTS_BLOCK=_make_imports(),
COLUMNS_BLOCK=self._make_columns(self.columns),
CONTEXT_BLOCK=self._make_context(),
QUERIES_BLOCK=self._make_queries(self.columns.keys()),
)
)

def _make_margins_dict(self, bin_names):
# TODO: Don't worry too much about the formatting here.
# Plan to run the output through black for consistency.
# https://github.com/opendp/dp-creator-ii/issues/50
margins = (
[
"""
(): dp.polars.Margin(
public_info="lengths",
),"""
]
+ [
f"""
("{bin_name}",): dp.polars.Margin(
public_info="keys",
),"""
for bin_name in bin_names
]
)

margins_dict = "{" + "".join(margins) + "\n }"
return margins_dict

def _make_columns(self, columns):
return "\n".join(
make_column_config_block(
name=name,
lower_bound=col.lower_bound,
upper_bound=col.upper_bound,
bin_count=col.bin_count,
)
for name, col in columns.items()
)

def _make_queries(self, column_names):
return "confidence = 0.95\n\n" + "\n".join(
_make_query(column_name) for column_name in column_names
)

def _make_partial_context(self):
weights = [column.weight for column in self.columns.values()]
column_names = [name_to_identifier(name) for name in self.columns.keys()]
privacy_unit_block = make_privacy_unit_block(self.contributions)
privacy_loss_block = make_privacy_loss_block(self.epsilon)
margins_dict = self._make_margins_dict([f"{name}_bin" for name in column_names])
columns = ", ".join([f"{name}_config" for name in column_names])
return (
Template("context")
.fill_expressions(
MARGINS_DICT=margins_dict,
COLUMNS=columns,
)
.fill_values(
WEIGHTS=weights,
)
.fill_blocks(
PRIVACY_UNIT_BLOCK=privacy_unit_block,
PRIVACY_LOSS_BLOCK=privacy_loss_block,
)
)


class NotebookGenerator(_CodeGenerator):
root_template = "notebook"

def _make_context(self):
return str(self._make_partial_context().fill_values(CSV_PATH=self.csv_path))


class ScriptGenerator(_CodeGenerator):
root_template = "script"

def _make_context(self):
return str(self._make_partial_context().fill_expressions(CSV_PATH="csv_path"))


# Public functions used to generate code snippets in the UI;
# These do not require an entire analysis plan, so they stand on their own.


def make_privacy_unit_block(contributions):
return str(Template("privacy_unit").fill_values(CONTRIBUTIONS=contributions))


def make_privacy_loss_block(epsilon):
return str(Template("privacy_loss").fill_values(EPSILON=epsilon))


def make_column_config_block(name, lower_bound, upper_bound, bin_count):
"""
>>> print(make_column_config_block(
... name="HW GRADE",
... lower_bound=0,
... upper_bound=100,
... bin_count=10
... ))
# From the public information, determine the bins for 'HW GRADE':
hw_grade_cut_points = make_cut_points(
lower_bound=0,
upper_bound=100,
bin_count=10,
)
<BLANKLINE>
# Use these bins to define a Polars column:
hw_grade_config = (
pl.col('HW GRADE')
.cut(hw_grade_cut_points)
.alias('hw_grade_bin') # Give the new column a name.
.cast(pl.String)
)
<BLANKLINE>
"""
snake_name = _snake_case(name)
return str(
Template("column_config")
.fill_expressions(
CUT_LIST_NAME=f"{snake_name}_cut_points",
POLARS_CONFIG_NAME=f"{snake_name}_config",
)
.fill_values(
LOWER_BOUND=lower_bound,
UPPER_BOUND=upper_bound,
BIN_COUNT=bin_count,
COLUMN_NAME=name,
BIN_COLUMN_NAME=f"{snake_name}_bin",
)
)


# Private helper functions:
# These do not depend on the AnalysisPlan,
# so it's better to keep them out of the class.


def _make_query(column_name):
    identifier = name_to_identifier(column_name)
    return str(
        Template("query")
        .fill_values(
            BIN_NAME=f"{identifier}_bin",
        )
        .fill_expressions(
            QUERY_NAME=f"{identifier}_query",
            ACCURACY_NAME=f"{identifier}_accuracy",
            HISTOGRAM_NAME=f"{identifier}_histogram",
        )
    )


def _snake_case(name: str):
"""
>>> _snake_case("HW GRADE")
'hw_grade'
"""
return re.sub(r"\W+", "_", name.lower())


def _make_imports():
return (
str(Template("imports").fill_values())
+ (Path(__file__).parent.parent / "shared.py").read_text()
)
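The abstract base class also makes the extension point explicit: a new output flavor only needs a `root_template` name and a `_make_context` override. A hypothetical sketch (the "report" template does not exist in this PR; it is only here to illustrate the pattern):

```python
from dp_wizard.utils.code_generators import _CodeGenerator


class ReportGenerator(_CodeGenerator):
    # Hypothetical: assumes a "report" template alongside "notebook" and "script".
    root_template = "report"

    def _make_context(self):
        # Bake the CSV path into the generated code, as NotebookGenerator does.
        return str(self._make_partial_context().fill_values(CSV_PATH=self.csv_path))
```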