Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Add household AI endpoint #453

Open
wants to merge 20 commits into
base: main
Choose a base branch
from
Open
Show file tree
Hide file tree
Changes from all commits
Commits
Show all changes
20 commits
Select commit Hold shift + click to select a range
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
1 change: 1 addition & 0 deletions .github/workflows/pr.yml
Original file line number Diff line number Diff line change
Expand Up @@ -54,6 +54,7 @@ jobs:
- name: Test the API
run: make test
env:
ANTHROPIC_API_KEY: ${{ secrets.ANTHROPIC_API_KEY }}
AUTH0_ADDRESS_NO_DOMAIN: ${{ secrets.AUTH0_ADDRESS_NO_DOMAIN }}
AUTH0_AUDIENCE_NO_DOMAIN: ${{ secrets.AUTH0_AUDIENCE_NO_DOMAIN }}
AUTH0_TEST_TOKEN_NO_DOMAIN: ${{ secrets.AUTH0_TEST_TOKEN_NO_DOMAIN }}
Expand Down
1 change: 1 addition & 0 deletions .github/workflows/push.yml
Original file line number Diff line number Diff line change
Expand Up @@ -70,6 +70,7 @@ jobs:
- name: Deploy
run: make deploy
env:
ANTHROPIC_API_KEY: ${{ secrets.ANTHROPIC_API_KEY }}
GOOGLE_APPLICATION_CREDENTIALS: ${{ secrets.GCP_SA_KEY }}
AUTH0_ADDRESS_NO_DOMAIN: ${{ secrets.AUTH0_ADDRESS_NO_DOMAIN }}
AUTH0_AUDIENCE_NO_DOMAIN: ${{ secrets.AUTH0_AUDIENCE_NO_DOMAIN }}
Expand Down
2 changes: 2 additions & 0 deletions gcp/export.py
Original file line number Diff line number Diff line change
Expand Up @@ -15,6 +15,7 @@
DB_USER = os.environ["USER_ANALYTICS_DB_USERNAME"]
DBPW = os.environ["USER_ANALYTICS_DB_PASSWORD"]
DB_CONN = os.environ["USER_ANALYTICS_DB_CONNECTION_NAME"]
ANTHROPIC = os.environ["ANTHROPIC_API_KEY"]

# Export GAE to to .gac.json

Expand All @@ -33,6 +34,7 @@
dockerfile = dockerfile.replace(".dbuser", DB_USER)
dockerfile = dockerfile.replace(".dbpw", DBPW)
dockerfile = dockerfile.replace(".dbconn", DB_CONN)
dockerfile = dockerfile.replace(".anthropic", ANTHROPIC)

with open(dockerfile_location, "w") as f:
f.write(dockerfile)
1 change: 1 addition & 0 deletions gcp/policyengine_household_api/Dockerfile
Original file line number Diff line number Diff line change
Expand Up @@ -7,6 +7,7 @@ ENV AUTH0_TEST_TOKEN_NO_DOMAIN .test-token
ENV USER_ANALYTICS_DB_USERNAME .dbuser
ENV USER_ANALYTICS_DB_PASSWORD .dbpw
ENV USER_ANALYTICS_DB_CONNECTION_NAME .dbconn
ENV ANTHROPIC_API_KEY .anthropic

WORKDIR /app

Expand Down
8 changes: 8 additions & 0 deletions policyengine_household_api/api.py
Original file line number Diff line number Diff line change
Expand Up @@ -24,6 +24,7 @@
from .endpoints import (
get_home,
get_calculate,
generate_ai_explainer,
)

# Configure authentication
Expand Down Expand Up @@ -81,6 +82,13 @@ def calculate(country_id):
return get_calculate(country_id)


@app.route("/<country_id>/ai_analysis", methods=["POST"])
@require_auth(None)
@log_analytics
def ai_analysis(country_id):
return generate_ai_explainer(country_id)


@app.route("/liveness_check", methods=["GET"])
def liveness_check():
return flask.Response(
Expand Down
47 changes: 44 additions & 3 deletions policyengine_household_api/country.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,8 +2,13 @@
from flask import Response
import json
from policyengine_core.taxbenefitsystems import TaxBenefitSystem
from policyengine_household_api.constants import COUNTRY_PACKAGE_VERSIONS
from typing import Union
from policyengine_household_api.utils import get_safe_json
from policyengine_household_api.utils import (
get_safe_json,
generate_tracer_output,
)
from policyengine_household_api.models.tracer import Tracer
from policyengine_core.parameters import (
ParameterNode,
Parameter,
Expand All @@ -16,6 +21,7 @@
from policyengine_core.periods import instant
import dpath
import math
from uuid import uuid4
import policyengine_uk
import policyengine_us
import policyengine_canada
Expand Down Expand Up @@ -288,7 +294,23 @@ def build_entities(self) -> dict:
data[entity.key] = entity_data
return data

def calculate(self, household: dict, reform: Union[dict, None] = None):
def calculate(
self,
household: dict,
reform: Union[dict, None] = None,
trace: bool = False,
) -> tuple[dict, Tracer | None]:
"""
Calculate a household under a policy reform.

Args:
household (dict): The household data.
reform (dict): The policy reform.
trace (bool): Whether to trace the calculation's computation tree; defaults to False

Returns:
tuple[dict, Tracer | None]: The calculated household and the tracer object, if trace is True, else None.
"""
if reform is not None and len(reform.keys()) > 0:
system = self.tax_benefit_system.clone()
for parameter_name in reform:
Expand Down Expand Up @@ -319,6 +341,9 @@ def calculate(self, household: dict, reform: Union[dict, None] = None):

household = json.loads(json.dumps(household))

# Run tracer on household if requested
if trace:
simulation.trace = True
requested_computations = get_requested_computations(household)

for (
Expand Down Expand Up @@ -381,7 +406,23 @@ def calculate(self, household: dict, reform: Union[dict, None] = None):
f"Error computing {variable_name} for {entity_id}: {e}"
)

return household
# Execute all household computation tree tracer operations
tracer: Tracer | None = None
if trace:
try:

# Generate tracer output
log_lines: list = generate_tracer_output(simulation)

# Take the tracer output and create a new tracer object
tracer = Tracer(self.country_id, tracer=log_lines)

except Exception as e:
# Do something here
print(f"Error computing tracer output: {e}")

# Return the household
return household, tracer


def create_policy_reform(policy_data: dict) -> dict:
Expand Down
1 change: 1 addition & 0 deletions policyengine_household_api/endpoints/__init__.py
Original file line number Diff line number Diff line change
@@ -1,2 +1,3 @@
from .home import get_home
from .household import get_calculate
from .ai_explainer import generate_ai_explainer
128 changes: 128 additions & 0 deletions policyengine_household_api/endpoints/ai_explainer.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,128 @@
import json
import logging
from flask import request, Response, stream_with_context
from typing import Generator
from uuid import UUID
from policyengine_household_api.country import COUNTRIES
from policyengine_household_api.models.tracer import Tracer
from policyengine_household_api.utils.validate_country import validate_country
from policyengine_household_api.utils.tracer import (
trigger_buffered_ai_analysis,
trigger_streaming_ai_analysis,
prompt_template,
)


@validate_country
def generate_ai_explainer(country_id: str) -> Response:
"""
Generate an AI explainer output for a given variable in
a particular household.

If a UUID is provided, we will fetch the calculation tree
tracer output from the Google Cloud bucket
and use that to generate the AI explainer output.

If a household object is provided, we will generate the
calculation tree tracer output, then use that to generate
AI explainer without storing.

If both are provided, household object takes precendence.

Args:
country_id (str): The country ID.

Request Args:
variable (str): The variable to explain.
household (dict): A household data object, if user does not
want us to store calculation tree tracer.
policy(dict): A policy object, if user doest not want us
to store calculation tree tracer.
uuid (str): The UUID of the tracer output, if user previously
calculated household and had us store calculation tree.
use_streaming (bool): Whether to use streaming Claude

Returns:
Response: The AI explainer output or an error.
"""

country = COUNTRIES.get(country_id)

payload = request.json

# Pull the UUID and variable from the query parameters
variable: str = payload.get("variable")
result_id: str | None = payload.get("result_id", None)
household: dict = payload.get("household", {})
policy: dict = payload.get("policy", {})
use_streaming: bool = payload.get("use_streaming", False)

# If household is provided, calculate the household and tracer and do not save
try:
if household:
_: dict # The calculated household results that we will not use
tracer_data: Tracer
_, tracer_data = country.calculate(household, policy, trace=True)
else:
tracer_data: Tracer = Tracer(country_id, tracer_uuid=result_id)
except Exception as e:
logging.exception(e)
return Response(
json.dumps(
dict(
status="error",
message=f"Error fetching tracer data: {e}",
)
),
status=500,
mimetype="application/json",
)

# Parse the tracer for the calculation tree of the variable
try:
tracer_segment: list[str] = tracer_data.parse_tracer_output(variable)
except Exception as e:
logging.exception(e)
return Response(
json.dumps(
dict(
status="error",
message=f"Error parsing tracer output: {e}",
)
),
status=500,
mimetype="application/json",
)

try:
# Generate the AI explainer prompt using the variable calculation tree
prompt = prompt_template.format(
variable=variable, tracer_segment=tracer_segment
)

# Pass all of this to Claude
if use_streaming:
analysis: Generator = trigger_streaming_ai_analysis(prompt)
return Response(
stream_with_context(analysis),
status=200,
)

analysis: str = trigger_buffered_ai_analysis(prompt)
return Response(
json.dumps({"response": analysis}),
status=200,
)

except Exception as e:
logging.exception(e)
return Response(
json.dumps(
dict(
status="error",
message=f"Error generating tracer analysis result using Claude: {e}",
)
),
status=500,
mimetype="application/json",
)
47 changes: 32 additions & 15 deletions policyengine_household_api/endpoints/household.py
Original file line number Diff line number Diff line change
@@ -1,33 +1,43 @@
from policyengine_household_api.country import (
COUNTRIES,
validate_country,
)
import json
from flask import Response, request
from uuid import UUID
from policyengine_household_api.country import COUNTRIES
from policyengine_household_api.models.tracer import Tracer
from policyengine_household_api.utils.validate_country import validate_country
import json
import logging


def get_calculate(country_id: str, add_missing: bool = False) -> dict:
"""Lightweight endpoint for passing in household JSON objects and calculating without storing data.
@validate_country
def get_calculate(country_id: str, add_missing: bool = False) -> Response:
"""
API endpoint for calculating households over specified economies, with optional data
storage (disabled by default) for computation tree tracer objects.

Args:
country_id (str): The country ID.
add_missing (bool = False): Whether or not to populate all
possible variables into the household object; this is a special
use case and should usually be kept at its default setting.
"""

country_not_found = validate_country(country_id)
if country_not_found:
return country_not_found

payload = request.json
household_json = payload.get("household", {})
policy_json = payload.get("policy", {})
trace: bool = payload.get("save_computation_tree", False)

country = COUNTRIES.get(country_id)

try:
result = country.calculate(household_json, policy_json)
result: dict
tracer: Tracer | None
result, tracer = country.calculate(
household_json, policy_json, trace=trace
)

if tracer:
tracer.upload_to_cloud_storage()

except Exception as e:
logging.exception(e)
response_body = dict(
Expand All @@ -40,8 +50,15 @@ def get_calculate(country_id: str, add_missing: bool = False) -> dict:
mimetype="application/json",
)

return dict(
status="ok",
message=None,
result=result,
return Response(
json.dumps(
dict(
status="ok",
message=None,
result=result,
result_id=str(tracer.tracer_uuid) if tracer else None,
)
),
200,
mimetype="application/json",
)
Loading
Loading