Skip to content

Commit

Permalink
Merge pull request #12 from uclahs-cds/aholmes-polish-db
Browse files Browse the repository at this point in the history
Add examples and other supporting code for the database module scaffolding
  • Loading branch information
aholmes authored Oct 2, 2023
2 parents a757cae + dbbe2d5 commit f82a5e2
Show file tree
Hide file tree
Showing 10 changed files with 144 additions and 19 deletions.
17 changes: 16 additions & 1 deletion src/database/BL_Python/database/engine/sqlite.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,6 +3,7 @@
from sqlalchemy import create_engine, event
from sqlalchemy.orm.scoping import ScopedSession
from sqlalchemy.orm.session import sessionmaker
from sqlalchemy.pool import Pool, StaticPool


class SQLiteScopedSession(ScopedSession):
Expand All @@ -15,8 +16,22 @@ def create(
"""
Create a new session factory for SQLite.
"""
poolclass: type[Pool] | None = None
# if the connection string is an SQLite in-memory database
# then make SQLAlchemy maintain a static pool of "connections"
# so that the in-memory database is not deallocated. Otherwise,
# the database would disappear when a thread is done with it.
# Note: SQLite will reject usage from other threads unless
# the connection string also contains `?check_same_thread=False`,
# e.g. `sqlite:///:memory:?check_same_thread=False`
if ":memory:" in connection_string:
poolclass = StaticPool

engine = create_engine(
connection_string, echo=echo, execution_options=execution_options or {}
connection_string,
echo=echo,
execution_options=execution_options or {},
poolclass=poolclass,
)

return SQLiteScopedSession(
Expand Down
20 changes: 12 additions & 8 deletions src/web/BL_Python/web/application.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,7 +3,7 @@
Flask entry point.
"""
import logging
from logging import Logger
from os import environ, path
from typing import Any, Optional, cast

Expand All @@ -22,7 +22,7 @@
# from CAP.app.services.user.login_manager import LoginManager
# from CAP.database.models.CAP import Base
from connexion.apps.flask_app import FlaskApp
from flask import Flask
from flask import Flask, url_for
from injector import Module
from lib_programname import get_path_executed_script

Expand All @@ -49,7 +49,7 @@ def create_app(
# just grow and grow.
# startup_builder: IStartupBuilder,
# config: Config,
):
) -> Flask:
"""
Bootstrap the Flask applcation.
Expand Down Expand Up @@ -115,6 +115,14 @@ def create_app(
flask_injector = configure_dependencies(app, application_modules=modules)
app.injector = flask_injector

if config.flask.openapi is not None and config.flask.openapi.use_swagger:
with app.app_context():
# use this logger so we can control where output is sent.
# the default logger retrieved here logs to the console.
app.injector.injector.get(Logger).info(
f"Swagger UI can be accessed at {url_for('/./_swagger_ui_index', _external=True)}"
)

return app


Expand Down Expand Up @@ -166,11 +174,7 @@ def configure_openapi(config: Config, name: Optional[str] = None):
# json_logging.config_root_logger()
app.logger.setLevel(environ.get("LOGLEVEL", "INFO").upper())

options: dict[str, bool] = {}
# TODO document that connexion[swagger-ui] must be installed
# for this to work
if config.flask.openapi.use_swagger:
options["swagger_ui"] = True
options: dict[str, bool] = {"swagger_ui": config.flask.openapi.use_swagger}

connexion_app.add_api(
f"{config.flask.app_name}/{config.flask.openapi.spec_path}",
Expand Down
1 change: 1 addition & 0 deletions src/web/BL_Python/web/config.py
Original file line number Diff line number Diff line change
Expand Up @@ -137,6 +137,7 @@ def _update_flask_config(self, flask_app_config: FlaskAppConfig):

class ConfigObject:
ENV = self.env
SERVER_NAME = f"{self.host}:{self.port}"

flask_app_config.from_object(ConfigObject)

Expand Down
Original file line number Diff line number Diff line change
@@ -1,4 +1,15 @@
# isort: off
from {{application_name}}._version import __version__
# isort: on

{% if module.database %}
from typing import Any, cast
from injector import Injector
from sqlalchemy import MetaData

from sqlalchemy.orm import Session
from {{application_name}}.modules.database import Base
{% endif %}

def create_app():
application_configs = []
Expand All @@ -13,8 +24,29 @@ def create_app():

from BL_Python.web.application import create_app as _create_app
# fmt: off
return _create_app(
app = _create_app(
application_configs=application_configs,
application_modules=application_modules
)
# fmt: on

{% if module.database %}
# For now, create the database and tables
# when the application starts. This behavior
# will be removed when Alembic is integrated.
session = cast(Injector, app.injector.injector).get(Session)
cast(MetaData, Base.metadata).create_all(session.bind) # pyright: ignore[reportGeneralTypeIssues]

{#
ideally this would use @inject w/ session: Session,
but something is preventing it from running or
sending in the dependencies to remove_db.
For now, just resolve it directly.
#}
@app.teardown_request
def remove_db(exception: Any):
session = app.injector.injector.get(Session)
session.rollback()
{% endif %}

return app
Original file line number Diff line number Diff line change
Expand Up @@ -25,6 +25,12 @@ def root(flask: Flask, config: Config, log: Logger):
if config['ENV'] != 'debug':
return "", 405

{#
TODO this does not group by like-URLs that are used for different methods.
for example, get_foo() and post_foo() might both use the URL /foo, but the
HTTP verbs GET (for get_) and POST (for post_). This makes the table look
awkward, so it may be prudent to think about how to represent that.
#}
# in debug environments, print a table of all routes and their allowed methods
output = "<table><tr><th>url</th><th>methods</th></tr>"
for rule in flask.url_map.iter_rules():
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -22,7 +22,11 @@ paths:
{% for endpoint in endpoints %}
/{{endpoint.endpoint_name}}:
get:
{% if module.database %}
description: "Get all entries in the {{endpoint.endpoint_name}} table."
{% else %}
description: "Hello, {{endpoint.endpoint_name}}!"
{% endif %}
operationId: "endpoints.{{endpoint.endpoint_name}}.get_{{endpoint.endpoint_name}}"
responses:
"200":
Expand All @@ -32,4 +36,26 @@ paths:
type: string
description: "Endpoint is working correctly."
summary: "A simple method that returns 200 as long as the endpoint is working correctly."
{% if module.database %}
post:
description: Add a new {{endpoint.endpoint_name.capitalize()}} to the {{endpoint.endpoint_name}} table.
operationId: "endpoints.{{endpoint.endpoint_name}}.add_{{endpoint.endpoint_name}}"
requestBody:
content:
application/json:
schema:
type: object
required:
- name
properties:
name:
type: string
responses:
201:
description: "The new {{endpoint.endpoint_name.capitalize()}} was successfully added to the {{endpoint.endpoint_name}} table."
content:
application/json:
schema:
type: string
{% endif %}
{% endfor %}
Original file line number Diff line number Diff line change
Expand Up @@ -2,6 +2,14 @@ from logging import Logger

from flask import Blueprint
from injector import inject
{% if module.database %}
from flask import request
from sqlalchemy.orm import Session
{% endif %}

{% if module.database %}
from {{application_name}}.modules.database import {{endpoint.endpoint_name.capitalize()}}
{% endif %}

{% if template_type != "openapi" %}
{{endpoint.endpoint_name}}_blueprint = Blueprint("{{endpoint.endpoint_name}}", __name__, url_prefix="/{{endpoint.endpoint_name}}")
Expand All @@ -12,5 +20,36 @@ from injector import inject
{% if template_type != "openapi" %}
@{{endpoint.endpoint_name}}_blueprint.route("/")
{% endif %}
{% if module.database %}
def get_{{endpoint.endpoint_name}}(session: Session, log: Logger):
entries = session.query({{endpoint.endpoint_name.capitalize()}}).all()
return {"names": [x.name for x in entries]}
{% else %}
def get_{{endpoint.endpoint_name}}(log: Logger):
return "Hello, {{endpoint.endpoint_name}}!"
{% endif %}

{% if module.database %}
@inject
{% if template_type != "openapi" %}
@{{endpoint.endpoint_name}}_blueprint.route("/", methods=["POST"])
{% endif %}
def add_{{endpoint.endpoint_name}}(session: Session, log: Logger):
name: str|None = None
try:
data = request.get_json(force=True)
if data is None or not (name := data["name"]):
return "Request JSON must contain a 'name' field with a string value.", 400

session.add({{endpoint.endpoint_name.capitalize()}}(name = name))
except Exception as e:
log.critical(str(e), exc_info=True)
return "Request JSON must contain a 'name' field with a string value.", 400

try:
session.commit()
return f"Added {name} to the {{endpoint.endpoint_name.capitalize()}} table.", 201
except Exception as e:
log.critical(str(e), exc_info=True)
return "An error occurred! Check the application logs.", 500
{% endif %}
Original file line number Diff line number Diff line change
Expand Up @@ -8,11 +8,13 @@ def on_create(config: dict[str, Any], log: Logger):
config["module"]["database"] = {}

connection_string = input(
"\nEnter a database connection string.\nBy default this is `sqlite:///:memory:`.\nRetain this default by pressing enter, or type something else.\n> "
"\nEnter a database connection string.\nBy default this is `sqlite:///:memory:?check_same_thread=False`.\nRetain this default by pressing enter, or type something else.\n> "
)

config["module"]["database"]["connection_string"] = (
connection_string if connection_string else "sqlite:///:memory:"
connection_string
if connection_string
else "sqlite:///:memory:?check_same_thread=False"
)
log.info(
f"Using database connection string `{config['module']['database']['connection_string']}`"
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -3,11 +3,12 @@ from sqlalchemy.ext.declarative import declarative_base

Base = declarative_base()


class Person(Base):
__tablename__ = "person"
{% for endpoint in endpoints %}
class {{endpoint.endpoint_name.capitalize()}}(Base):
__tablename__ = "{{endpoint.endpoint_name}}"
id = Column(Integer, primary_key=True)
name = Column(String(50))

def to_dict(self):
return {"id": self.id, "name": self.name}
def __repr__(self):
return f"<{{endpoint.endpoint_name.capitalize()}} {self.name}>"
{% endfor %}
3 changes: 1 addition & 2 deletions src/web/pyproject.toml
Original file line number Diff line number Diff line change
Expand Up @@ -29,8 +29,7 @@ dependencies = [
"flask-login == 0.6.2",
"types-flask == 1.1.6",
"connexion == 2.14.2",
# TODO is this one necessary, or should it perhaps belong in the app?
"connexion[swagger-ui-bundle] == 2.14.2",
"swagger_ui_bundle==0.0.9",
# specific version because Jinja (Flask dependency)
# does not specify an upper bound for MarkupSafe,
# and it pulls in a breaking change that prevents
Expand Down

0 comments on commit f82a5e2

Please sign in to comment.