From d3d05789f7172a97ef52993d8a577b2b13396899 Mon Sep 17 00:00:00 2001 From: Neil Shaabi <66903165+neilshaabi@users.noreply.github.com> Date: Sat, 24 Feb 2024 13:43:13 +0000 Subject: [PATCH 1/6] Changes --- app/__init__.py | 16 ++++++++++------ app/config.py | 7 ++++--- tests/conftest.py | 10 +++++----- tests/test_auth.py | 7 +++++++ tests/test_main.py | 10 ---------- 5 files changed, 26 insertions(+), 24 deletions(-) create mode 100644 tests/test_auth.py diff --git a/app/__init__.py b/app/__init__.py index 429306e..6f139dc 100644 --- a/app/__init__.py +++ b/app/__init__.py @@ -1,3 +1,5 @@ +import os + from flask import Flask from flask_login import LoginManager from flask_mail import Mail @@ -5,7 +7,7 @@ from flask_wtf.csrf import CSRFProtect from itsdangerous import URLSafeTimedSerializer -from app.config import Config, selected_config +from app.config import CONFIGS, Config db = SQLAlchemy() csrf = CSRFProtect() @@ -15,6 +17,7 @@ login_manager.login_view = "/" login_manager.login_message = None +selected_config = CONFIGS[os.environ["ENV"]] def create_app(config: Config = selected_config): app = Flask(__name__) @@ -26,18 +29,19 @@ def create_app(config: Config = selected_config): login_manager.init_app(app) app.serialiser = URLSafeTimedSerializer(app.config["SECRET_KEY"]) - # Reset database - from app.models import User, insertDummyData - + # Reset database when not in production if app.config["RESET_DB"]: with app.app_context(): db.drop_all() db.create_all() - insertDummyData() + + # Insert fake data + if app.config["FAKE_DATA"]: + from app.models import insertDummyData + insertDummyData() # Register blueprints from app.views import auth, main - app.register_blueprint(main.bp) app.register_blueprint(auth.bp) diff --git a/app/config.py b/app/config.py index 47606de..139fbad 100644 --- a/app/config.py +++ b/app/config.py @@ -25,18 +25,21 @@ class Config(object): class DevConfig(Config): DEBUG: bool = True RESET_DB: bool = True + FAKE_DATA: bool = True SQLALCHEMY_DATABASE_URI: str = "sqlite:///" + os.path.join(basedir, "mindli.sqlite") class ProdConfig(Config): DEBUG: bool = False RESET_DB: bool = False + FAKE_DATA: bool = False SQLALCHEMY_DATABASE_URI: str = os.environ["DATABASE_URL"] class TestConfig(Config): TESTING: bool = True - RESET_DB: bool = False + RESET_DB: bool = True + FAKE_DATA: bool = False SQLALCHEMY_DATABASE_URI: str = "sqlite://" # Use in-memory database @@ -45,5 +48,3 @@ class TestConfig(Config): "prod": ProdConfig, "test": TestConfig, } - -selected_config = CONFIGS[os.environ["ENV"]] diff --git a/tests/conftest.py b/tests/conftest.py index 25fdf23..325ac60 100644 --- a/tests/conftest.py +++ b/tests/conftest.py @@ -1,17 +1,17 @@ +from typing import Any, Generator + import pytest from flask import Flask from flask.testing import FlaskClient -from app import create_app, db +from app import create_app from app.config import TestConfig @pytest.fixture() -def app() -> Flask: +def app() -> Generator[Flask, Any, None]: app = create_app(config=TestConfig) - with app.app_context(): - db.create_all() - return app + yield app @pytest.fixture() diff --git a/tests/test_auth.py b/tests/test_auth.py new file mode 100644 index 0000000..710b247 --- /dev/null +++ b/tests/test_auth.py @@ -0,0 +1,7 @@ +from flask.testing import FlaskClient + + +def test_get_login(client: FlaskClient): + get_response = client.get("/login") + assert get_response.status_code == 200 + diff --git a/tests/test_main.py b/tests/test_main.py index 88da129..fb55eb5 100644 --- a/tests/test_main.py +++ b/tests/test_main.py @@ -2,20 +2,10 @@ def test_get_index(client: FlaskClient): - """ - GIVEN a Flask client configured for testing - WHEN the '/' page is requested (GET) - THEN check that the response is valid - """ get_response = client.get("/") assert get_response.status_code == 200 def test_post_index(client: FlaskClient): - """ - GIVEN a Flask client configured for testing - WHEN the '/' page is requested (POST) - THEN check that the response is forbidden - """ post_response = client.post("/") assert post_response.status_code == 405 From 07c04d03fbcfaa20295acbdc66fa0f05bbfdf3a8 Mon Sep 17 00:00:00 2001 From: Neil Shaabi <66903165+neilshaabi@users.noreply.github.com> Date: Sat, 24 Feb 2024 14:06:41 +0000 Subject: [PATCH 2/6] Added Flask-Migrate support --- app/__init__.py | 15 +- app/config.py | 6 - migrations/README | 1 + migrations/alembic.ini | 50 +++++++ migrations/env.py | 113 +++++++++++++++ migrations/script.py.mako | 24 ++++ migrations/versions/afe16cfee729_.py | 202 +++++++++++++++++++++++++++ requirements.txt | 8 +- 8 files changed, 400 insertions(+), 19 deletions(-) create mode 100644 migrations/README create mode 100644 migrations/alembic.ini create mode 100644 migrations/env.py create mode 100644 migrations/script.py.mako create mode 100644 migrations/versions/afe16cfee729_.py diff --git a/app/__init__.py b/app/__init__.py index 6f139dc..d995fc3 100644 --- a/app/__init__.py +++ b/app/__init__.py @@ -3,6 +3,7 @@ from flask import Flask from flask_login import LoginManager from flask_mail import Mail +from flask_migrate import Migrate from flask_sqlalchemy import SQLAlchemy from flask_wtf.csrf import CSRFProtect from itsdangerous import URLSafeTimedSerializer @@ -10,9 +11,11 @@ from app.config import CONFIGS, Config db = SQLAlchemy() +migrate = Migrate() csrf = CSRFProtect() mail = Mail() + login_manager = LoginManager() login_manager.login_view = "/" login_manager.login_message = None @@ -25,21 +28,11 @@ def create_app(config: Config = selected_config): # Initialise extensions db.init_app(app) + migrate.init_app(app, db) mail.init_app(app) login_manager.init_app(app) app.serialiser = URLSafeTimedSerializer(app.config["SECRET_KEY"]) - # Reset database when not in production - if app.config["RESET_DB"]: - with app.app_context(): - db.drop_all() - db.create_all() - - # Insert fake data - if app.config["FAKE_DATA"]: - from app.models import insertDummyData - insertDummyData() - # Register blueprints from app.views import auth, main app.register_blueprint(main.bp) diff --git a/app/config.py b/app/config.py index 139fbad..a653733 100644 --- a/app/config.py +++ b/app/config.py @@ -24,22 +24,16 @@ class Config(object): class DevConfig(Config): DEBUG: bool = True - RESET_DB: bool = True - FAKE_DATA: bool = True SQLALCHEMY_DATABASE_URI: str = "sqlite:///" + os.path.join(basedir, "mindli.sqlite") class ProdConfig(Config): DEBUG: bool = False - RESET_DB: bool = False - FAKE_DATA: bool = False SQLALCHEMY_DATABASE_URI: str = os.environ["DATABASE_URL"] class TestConfig(Config): TESTING: bool = True - RESET_DB: bool = True - FAKE_DATA: bool = False SQLALCHEMY_DATABASE_URI: str = "sqlite://" # Use in-memory database diff --git a/migrations/README b/migrations/README new file mode 100644 index 0000000..0e04844 --- /dev/null +++ b/migrations/README @@ -0,0 +1 @@ +Single-database configuration for Flask. diff --git a/migrations/alembic.ini b/migrations/alembic.ini new file mode 100644 index 0000000..ec9d45c --- /dev/null +++ b/migrations/alembic.ini @@ -0,0 +1,50 @@ +# A generic, single database configuration. + +[alembic] +# template used to generate migration files +# file_template = %%(rev)s_%%(slug)s + +# set to 'true' to run the environment during +# the 'revision' command, regardless of autogenerate +# revision_environment = false + + +# Logging configuration +[loggers] +keys = root,sqlalchemy,alembic,flask_migrate + +[handlers] +keys = console + +[formatters] +keys = generic + +[logger_root] +level = WARN +handlers = console +qualname = + +[logger_sqlalchemy] +level = WARN +handlers = +qualname = sqlalchemy.engine + +[logger_alembic] +level = INFO +handlers = +qualname = alembic + +[logger_flask_migrate] +level = INFO +handlers = +qualname = flask_migrate + +[handler_console] +class = StreamHandler +args = (sys.stderr,) +level = NOTSET +formatter = generic + +[formatter_generic] +format = %(levelname)-5.5s [%(name)s] %(message)s +datefmt = %H:%M:%S diff --git a/migrations/env.py b/migrations/env.py new file mode 100644 index 0000000..4c97092 --- /dev/null +++ b/migrations/env.py @@ -0,0 +1,113 @@ +import logging +from logging.config import fileConfig + +from flask import current_app + +from alembic import context + +# this is the Alembic Config object, which provides +# access to the values within the .ini file in use. +config = context.config + +# Interpret the config file for Python logging. +# This line sets up loggers basically. +fileConfig(config.config_file_name) +logger = logging.getLogger('alembic.env') + + +def get_engine(): + try: + # this works with Flask-SQLAlchemy<3 and Alchemical + return current_app.extensions['migrate'].db.get_engine() + except (TypeError, AttributeError): + # this works with Flask-SQLAlchemy>=3 + return current_app.extensions['migrate'].db.engine + + +def get_engine_url(): + try: + return get_engine().url.render_as_string(hide_password=False).replace( + '%', '%%') + except AttributeError: + return str(get_engine().url).replace('%', '%%') + + +# add your model's MetaData object here +# for 'autogenerate' support +# from myapp import mymodel +# target_metadata = mymodel.Base.metadata +config.set_main_option('sqlalchemy.url', get_engine_url()) +target_db = current_app.extensions['migrate'].db + +# other values from the config, defined by the needs of env.py, +# can be acquired: +# my_important_option = config.get_main_option("my_important_option") +# ... etc. + + +def get_metadata(): + if hasattr(target_db, 'metadatas'): + return target_db.metadatas[None] + return target_db.metadata + + +def run_migrations_offline(): + """Run migrations in 'offline' mode. + + This configures the context with just a URL + and not an Engine, though an Engine is acceptable + here as well. By skipping the Engine creation + we don't even need a DBAPI to be available. + + Calls to context.execute() here emit the given string to the + script output. + + """ + url = config.get_main_option("sqlalchemy.url") + context.configure( + url=url, target_metadata=get_metadata(), literal_binds=True + ) + + with context.begin_transaction(): + context.run_migrations() + + +def run_migrations_online(): + """Run migrations in 'online' mode. + + In this scenario we need to create an Engine + and associate a connection with the context. + + """ + + # this callback is used to prevent an auto-migration from being generated + # when there are no changes to the schema + # reference: http://alembic.zzzcomputing.com/en/latest/cookbook.html + def process_revision_directives(context, revision, directives): + if getattr(config.cmd_opts, 'autogenerate', False): + script = directives[0] + if script.upgrade_ops.is_empty(): + directives[:] = [] + logger.info('No changes in schema detected.') + + conf_args = current_app.extensions['migrate'].configure_args + if conf_args.get("process_revision_directives") is None: + conf_args["process_revision_directives"] = process_revision_directives + + connectable = get_engine() + + with connectable.connect() as connection: + context.configure( + connection=connection, + target_metadata=get_metadata(), + **conf_args + ) + + with context.begin_transaction(): + context.run_migrations() + + +if context.is_offline_mode(): + run_migrations_offline() +else: + run_migrations_online() diff --git a/migrations/script.py.mako b/migrations/script.py.mako new file mode 100644 index 0000000..2c01563 --- /dev/null +++ b/migrations/script.py.mako @@ -0,0 +1,24 @@ +"""${message} + +Revision ID: ${up_revision} +Revises: ${down_revision | comma,n} +Create Date: ${create_date} + +""" +from alembic import op +import sqlalchemy as sa +${imports if imports else ""} + +# revision identifiers, used by Alembic. +revision = ${repr(up_revision)} +down_revision = ${repr(down_revision)} +branch_labels = ${repr(branch_labels)} +depends_on = ${repr(depends_on)} + + +def upgrade(): + ${upgrades if upgrades else "pass"} + + +def downgrade(): + ${downgrades if downgrades else "pass"} diff --git a/migrations/versions/afe16cfee729_.py b/migrations/versions/afe16cfee729_.py new file mode 100644 index 0000000..5994211 --- /dev/null +++ b/migrations/versions/afe16cfee729_.py @@ -0,0 +1,202 @@ +"""empty message + +Revision ID: afe16cfee729 +Revises: +Create Date: 2024-02-24 13:57:36.774279 + +""" +from alembic import op +import sqlalchemy as sa + + +# revision identifiers, used by Alembic. +revision = 'afe16cfee729' +down_revision = None +branch_labels = None +depends_on = None + + +def upgrade(): + # ### commands auto generated by Alembic - please adjust! ### + op.create_table('intervention', + sa.Column('id', sa.Integer(), nullable=False), + sa.Column('name', sa.String(length=50), nullable=False), + sa.PrimaryKeyConstraint('id'), + sa.UniqueConstraint('name') + ) + op.create_table('issue', + sa.Column('id', sa.Integer(), nullable=False), + sa.Column('name', sa.String(length=50), nullable=False), + sa.PrimaryKeyConstraint('id'), + sa.UniqueConstraint('name') + ) + op.create_table('language', + sa.Column('id', sa.Integer(), nullable=False), + sa.Column('name', sa.String(length=50), nullable=False), + sa.Column('iso639_1', sa.String(length=2), nullable=True), + sa.Column('iso639_2', sa.String(length=3), nullable=True), + sa.PrimaryKeyConstraint('id'), + sa.UniqueConstraint('iso639_1'), + sa.UniqueConstraint('iso639_2'), + sa.UniqueConstraint('name') + ) + op.create_table('user', + sa.Column('id', sa.Integer(), nullable=False), + sa.Column('email', sa.String(length=254), nullable=False), + sa.Column('password_hash', sa.String(length=255), nullable=False), + sa.Column('first_name', sa.String(length=50), nullable=False), + sa.Column('last_name', sa.String(length=50), nullable=False), + sa.Column('date_joined', sa.Date(), nullable=False), + sa.Column('role', sa.Enum('CLIENT', 'THERAPIST', name='userrole'), nullable=False), + sa.Column('verified', sa.Boolean(), nullable=False), + sa.Column('active', sa.Boolean(), nullable=False), + sa.Column('gender', sa.Enum('MALE', 'FEMALE', 'NON_BINARY', name='gender'), nullable=True), + sa.Column('photo_url', sa.String(length=255), nullable=True), + sa.Column('timezone', sa.String(length=50), nullable=True), + sa.Column('currency', sa.String(length=3), nullable=True), + sa.PrimaryKeyConstraint('id') + ) + with op.batch_alter_table('user', schema=None) as batch_op: + batch_op.create_index(batch_op.f('ix_user_email'), ['email'], unique=True) + + op.create_table('client', + sa.Column('id', sa.Integer(), nullable=False), + sa.Column('user_id', sa.Integer(), nullable=False), + sa.Column('preferred_gender', sa.Enum('MALE', 'FEMALE', 'NON_BINARY', name='gender'), nullable=True), + sa.Column('preferred_language_id', sa.Integer(), nullable=True), + sa.ForeignKeyConstraint(['preferred_language_id'], ['language.id'], ), + sa.ForeignKeyConstraint(['user_id'], ['user.id'], ), + sa.PrimaryKeyConstraint('id') + ) + with op.batch_alter_table('client', schema=None) as batch_op: + batch_op.create_index(batch_op.f('ix_client_user_id'), ['user_id'], unique=False) + + op.create_table('therapist', + sa.Column('id', sa.Integer(), nullable=False), + sa.Column('user_id', sa.Integer(), nullable=False), + sa.Column('country', sa.String(length=50), nullable=False), + sa.Column('affilitation', sa.Text(), nullable=True), + sa.Column('bio', sa.Text(), nullable=True), + sa.Column('link', sa.String(length=255), nullable=True), + sa.Column('location', sa.String(length=255), nullable=True), + sa.Column('registrations', sa.Text(), nullable=True), + sa.Column('qualifications', sa.Text(), nullable=True), + sa.Column('years_of_experience', sa.Integer(), nullable=True), + sa.ForeignKeyConstraint(['user_id'], ['user.id'], ), + sa.PrimaryKeyConstraint('id') + ) + with op.batch_alter_table('therapist', schema=None) as batch_op: + batch_op.create_index(batch_op.f('ix_therapist_user_id'), ['user_id'], unique=False) + + op.create_table('availability', + sa.Column('id', sa.Integer(), nullable=False), + sa.Column('therapist_id', sa.Integer(), nullable=False), + sa.Column('day_of_week', sa.Integer(), nullable=True), + sa.Column('start_time', sa.Time(), nullable=True), + sa.Column('end_time', sa.Time(), nullable=True), + sa.Column('specific_date', sa.Date(), nullable=True), + sa.ForeignKeyConstraint(['therapist_id'], ['therapist.id'], ), + sa.PrimaryKeyConstraint('id') + ) + with op.batch_alter_table('availability', schema=None) as batch_op: + batch_op.create_index(batch_op.f('ix_availability_therapist_id'), ['therapist_id'], unique=False) + + op.create_table('client_issue', + sa.Column('client_id', sa.Integer(), nullable=False), + sa.Column('issue_id', sa.Integer(), nullable=False), + sa.ForeignKeyConstraint(['client_id'], ['client.id'], ), + sa.ForeignKeyConstraint(['issue_id'], ['issue.id'], ), + sa.PrimaryKeyConstraint('client_id', 'issue_id') + ) + op.create_table('session_type', + sa.Column('id', sa.Integer(), nullable=False), + sa.Column('therapist_id', sa.Integer(), nullable=False), + sa.Column('name', sa.String(length=255), nullable=False), + sa.Column('session_duration', sa.Integer(), nullable=False), + sa.Column('fee_amount', sa.Float(), nullable=False), + sa.Column('fee_currency', sa.String(length=3), nullable=False), + sa.Column('session_format', sa.Enum('FACE', 'AUDIO', 'VIDEO', name='sessionformat'), nullable=True), + sa.Column('notes', sa.Text(), nullable=True), + sa.ForeignKeyConstraint(['therapist_id'], ['therapist.id'], ), + sa.PrimaryKeyConstraint('id') + ) + with op.batch_alter_table('session_type', schema=None) as batch_op: + batch_op.create_index(batch_op.f('ix_session_type_therapist_id'), ['therapist_id'], unique=False) + + op.create_table('therapist_format', + sa.Column('therapist_id', sa.Integer(), nullable=False), + sa.Column('session_format', sa.Enum('FACE', 'AUDIO', 'VIDEO', name='sessionformat'), nullable=False), + sa.ForeignKeyConstraint(['therapist_id'], ['therapist.id'], ), + sa.PrimaryKeyConstraint('therapist_id', 'session_format') + ) + op.create_table('therapist_intervention', + sa.Column('therapist_id', sa.Integer(), nullable=False), + sa.Column('intervention_id', sa.Integer(), nullable=False), + sa.ForeignKeyConstraint(['intervention_id'], ['intervention.id'], ), + sa.ForeignKeyConstraint(['therapist_id'], ['therapist.id'], ), + sa.PrimaryKeyConstraint('therapist_id', 'intervention_id') + ) + op.create_table('therapist_issue', + sa.Column('therapist_id', sa.Integer(), nullable=False), + sa.Column('issue_id', sa.Integer(), nullable=False), + sa.ForeignKeyConstraint(['issue_id'], ['issue.id'], ), + sa.ForeignKeyConstraint(['therapist_id'], ['therapist.id'], ), + sa.PrimaryKeyConstraint('therapist_id', 'issue_id') + ) + op.create_table('therapist_language', + sa.Column('therapist_id', sa.Integer(), nullable=False), + sa.Column('language_id', sa.Integer(), nullable=False), + sa.ForeignKeyConstraint(['language_id'], ['language.id'], ), + sa.ForeignKeyConstraint(['therapist_id'], ['therapist.id'], ), + sa.PrimaryKeyConstraint('therapist_id', 'language_id') + ) + op.create_table('unavailability', + sa.Column('id', sa.Integer(), nullable=False), + sa.Column('therapist_id', sa.Integer(), nullable=False), + sa.Column('start_date', sa.Date(), nullable=False), + sa.Column('end_date', sa.Date(), nullable=False), + sa.Column('reason', sa.Text(), nullable=True), + sa.ForeignKeyConstraint(['therapist_id'], ['therapist.id'], ), + sa.PrimaryKeyConstraint('id') + ) + with op.batch_alter_table('unavailability', schema=None) as batch_op: + batch_op.create_index(batch_op.f('ix_unavailability_therapist_id'), ['therapist_id'], unique=False) + + # ### end Alembic commands ### + + +def downgrade(): + # ### commands auto generated by Alembic - please adjust! ### + with op.batch_alter_table('unavailability', schema=None) as batch_op: + batch_op.drop_index(batch_op.f('ix_unavailability_therapist_id')) + + op.drop_table('unavailability') + op.drop_table('therapist_language') + op.drop_table('therapist_issue') + op.drop_table('therapist_intervention') + op.drop_table('therapist_format') + with op.batch_alter_table('session_type', schema=None) as batch_op: + batch_op.drop_index(batch_op.f('ix_session_type_therapist_id')) + + op.drop_table('session_type') + op.drop_table('client_issue') + with op.batch_alter_table('availability', schema=None) as batch_op: + batch_op.drop_index(batch_op.f('ix_availability_therapist_id')) + + op.drop_table('availability') + with op.batch_alter_table('therapist', schema=None) as batch_op: + batch_op.drop_index(batch_op.f('ix_therapist_user_id')) + + op.drop_table('therapist') + with op.batch_alter_table('client', schema=None) as batch_op: + batch_op.drop_index(batch_op.f('ix_client_user_id')) + + op.drop_table('client') + with op.batch_alter_table('user', schema=None) as batch_op: + batch_op.drop_index(batch_op.f('ix_user_email')) + + op.drop_table('user') + op.drop_table('language') + op.drop_table('issue') + op.drop_table('intervention') + # ### end Alembic commands ### diff --git a/requirements.txt b/requirements.txt index 74ea092..0c63cf0 100644 --- a/requirements.txt +++ b/requirements.txt @@ -1,17 +1,21 @@ +alembic==1.12.1 blinker==1.6.3 click==8.1.7 exceptiongroup==1.2.0 Flask==2.2.5 Flask-Login==0.6.3 Flask-Mail==0.9.1 -Flask-SQLAlchemy==3.0.5 +Flask-Migrate==4.0.5 +flask-sqlalchemy==3.0.5 Flask-WTF==1.1.1 greenlet==3.0.3 importlib-metadata==6.7.0 +importlib-resources==5.12.0 iniconfig==2.0.0 isort==5.11.5 itsdangerous==2.1.2 Jinja2==3.1.2 +Mako==1.2.4 MarkupSafe==2.1.3 mypy-extensions==1.0.0 packaging==23.2 @@ -24,7 +28,7 @@ python-dotenv==0.21.1 SQLAlchemy==2.0.23 tomli==2.0.1 typed-ast==1.5.5 -typing_extensions==4.7.1 +typing-extensions==4.7.1 Werkzeug==2.2.3 WTForms==3.0.1 zipp==3.15.0 From 5ac1b1565d04d37319734c1a4046554c9ca69d5c Mon Sep 17 00:00:00 2001 From: Neil Shaabi <66903165+neilshaabi@users.noreply.github.com> Date: Sun, 25 Feb 2024 18:33:49 +0000 Subject: [PATCH 3/6] Unit tests for register and login routes --- app/models.py | 11 --- app/static/css/main.css | 5 +- app/static/js/forms.js | 21 +++-- app/utils/password.py | 14 --- app/utils/validators.py | 17 ++++ app/views/auth.py | 52 ++++++----- migrations/env.py | 3 +- migrations/versions/afe16cfee729_.py | 3 +- tests/conftest.py | 51 ++++++++++- tests/test_auth.py | 7 -- tests/test_login.py | 73 +++++++++++++++ tests/test_main.py | 8 +- tests/test_register.py | 130 +++++++++++++++++++++++++++ 13 files changed, 317 insertions(+), 78 deletions(-) delete mode 100644 app/utils/password.py create mode 100644 app/utils/validators.py delete mode 100644 tests/test_auth.py create mode 100644 tests/test_login.py create mode 100644 tests/test_register.py diff --git a/app/models.py b/app/models.py index 08a0791..8f766f4 100644 --- a/app/models.py +++ b/app/models.py @@ -182,17 +182,6 @@ class Unavailability(db.Model): def insertDummyData() -> None: users: List[User] = [ - User( - email="client@example.com", - password_hash=generate_password_hash("password"), - first_name="John", - last_name="Smith", - date_joined=date.today(), - role=UserRole.CLIENT, - verified=True, - active=True, - gender=Gender.MALE, - ), User( email="therapist@example.com", password_hash=generate_password_hash("password"), diff --git a/app/static/css/main.css b/app/static/css/main.css index a6ba030..95af38a 100644 --- a/app/static/css/main.css +++ b/app/static/css/main.css @@ -213,7 +213,10 @@ input:-webkit-autofill:focus { } .btn-primary, -.btn-primary:disabled { +.btn-primary.active, +.btn-primary.show, +.btn-primary:disabled, +:not(.btn-check)+.btn:active { color: white; background: var(--colour-primary); } diff --git a/app/static/js/forms.js b/app/static/js/forms.js index 669beff..ed10f03 100644 --- a/app/static/js/forms.js +++ b/app/static/js/forms.js @@ -13,12 +13,6 @@ function showLoadingBtn(isLoading) { // Function to display error messages function displayFormErrors(errors) { - - // Clear previous errors - $('.error-message').remove(); - $('.input-error').removeClass('input-error'); - - // Display new errors below corresponding input fields for (const key in errors) { const inputField = $('#' + key); const errorMessage = $( @@ -33,10 +27,19 @@ function displayFormErrors(errors) { } function ajaxFormResponseHandler(response) { - if (response.errors) { - displayFormErrors(response.errors); + + // Clear previous errors + $('.error-message').remove(); + $('.input-error').removeClass('input-error'); + + if (response.success) { + if (response.url) { + window.location = response.url; + } } else { - window.location = response.url; + if (response.errors) { + displayFormErrors(response.errors); + } } } diff --git a/app/utils/password.py b/app/utils/password.py deleted file mode 100644 index 23195e2..0000000 --- a/app/utils/password.py +++ /dev/null @@ -1,14 +0,0 @@ -def isValidPassword(password: str) -> bool: - """ - Returns whether a given password meets the security requirements - (at least 8 characters with at least one digit, one uppercase letter - and one lowercase letter)""" - if ( - (len(password) < 8) - or (not any(char.isdigit() for char in password)) - or (not any(char.isupper() for char in password)) - or (not any(char.islower() for char in password)) - ): - return False - else: - return True diff --git a/app/utils/validators.py b/app/utils/validators.py new file mode 100644 index 0000000..43ec233 --- /dev/null +++ b/app/utils/validators.py @@ -0,0 +1,17 @@ +import re + +def isValidText(text: str) -> bool: + return text and not text.isspace() + +def isValidEmail(email: str) -> bool: + email_regex = r'^[a-zA-Z0-9._%+-]+@[a-zA-Z0-9.-]+\.[a-zA-Z]{2,}$' + return email and re.match(email_regex, email) + +def isValidPassword(password: str) -> bool: + return ( + password + and len(password) >= 8 + and any(char.isdigit() for char in password) + and any(char.isupper() for char in password) + and any(char.islower() for char in password) + ) \ No newline at end of file diff --git a/app/views/auth.py b/app/views/auth.py index 17091b7..45ff3b5 100644 --- a/app/views/auth.py +++ b/app/views/auth.py @@ -4,13 +4,12 @@ render_template, request, session, url_for) from flask_login import login_user, logout_user from itsdangerous import BadSignature, SignatureExpired -from markupsafe import escape from werkzeug.security import check_password_hash, generate_password_hash from app import db, mail from app.models import User, UserRole from app.utils.mail import EmailMessage, EmailSubject -from app.utils.password import isValidPassword +from app.utils.validators import isValidEmail, isValidPassword, isValidText bp = Blueprint("auth", __name__) @@ -37,26 +36,27 @@ def register() -> Response: # Get form data role = request.form.get("role") - first_name = escape(request.form.get("first_name")) - last_name = escape(request.form.get("last_name")) + first_name = request.form.get("first_name") + last_name = request.form.get("last_name") email = request.form.get("email") password = request.form.get("password") + # Validate input - if not role: + if not role or (role not in set(r.value for r in UserRole)): errors["role"] = "Account type is required" - if not first_name or first_name.isspace(): + if not isValidText(first_name): errors["first_name"] = "First name is required" - if not last_name or last_name.isspace(): + if not isValidText(last_name): errors["last_name"] = "Last name is required" - if not email or email.isspace(): - errors["email"] = "Email is required" - elif User.query.filter_by(email=email.lower()).first(): - errors["email"] = "Email address is already in use" if not isValidPassword(password): errors["password"] = "Password does not meet requirements" + if not isValidEmail(email): + errors["email"] = "Invalid email address" + elif db.session.execute(db.select(User).filter_by(email=email.lower())).scalar_one_or_none() is not None: + errors["email"] = "Email address is already in use" if errors: - return jsonify({"errors": errors}) + return jsonify({"success": False, "errors": errors}) # Proceed with successful registration else: @@ -75,7 +75,7 @@ def register() -> Response: db.session.add(user) db.session.commit() - # Send verification email and redirect + # Send verification email email_message = EmailMessage( mail=mail, subject=EmailSubject.EMAIL_VERIFICATION, @@ -83,8 +83,10 @@ def register() -> Response: serialiser=current_app.serialiser, ) email_message.send() + + # Store email in session for email verification session["email"] = email - return jsonify({"url": url_for("auth.verify_email")}) + return jsonify({"success": True, "url": url_for("auth.verify_email")}) # Request method is GET else: @@ -95,11 +97,13 @@ def register() -> Response: # Logs user in if credentials are valid @bp.route("/login", methods=["GET", "POST"]) def login() -> Response: + if request.method == "POST": + errors = {} # Get form data - email = request.form.get("email").lower() + email = request.form.get("email") password = request.form.get("password") # Validate input @@ -108,21 +112,21 @@ def login() -> Response: if not password: errors["password"] = "Password is required" else: - user = db.session.execute(db.select(User).filter_by(email=email)).scalar_one_or_none() + user = db.session.execute(db.select(User).filter_by(email=email.lower())).scalar_one_or_none() if user is None or not check_password_hash(user.password_hash, password): errors["password"] = "Incorrect email/password" if errors: - return jsonify({"errors": errors}) + return jsonify({"success": False, "errors": errors}) # Ensure user's email is verified if not user.verified: # Store email in session for verification and redirect session["email"] = email - return jsonify({"url": url_for("auth.verify_email")}) + return jsonify({"success": True, "url": url_for("auth.verify_email")}) # Log user in and redirect to home page login_user(user) - return jsonify({"url": url_for("main.index")}) + return jsonify({"success": True, "url": url_for("main.index")}) # Request method is GET else: @@ -154,7 +158,7 @@ def verify_email() -> Response: ) email_message.send() flash(f"Email verification instructions sent to {user.email}") - return jsonify({"url": url_for("main.index")}) + return jsonify({"success": True, "url": url_for("main.index")}) else: return render_template("verify-email.html", email=session["email"]) @@ -211,7 +215,7 @@ def reset_request() -> Response: # Return errors if any if errors: - return jsonify({"errors": errors}) + return jsonify({"success": False, "errors": errors}) # Send reset email else: @@ -223,7 +227,7 @@ def reset_request() -> Response: ) email_message.send() flash(f"Password reset instructions sent to {email}") - return jsonify({"url": url_for("main.index")}) + return jsonify({"success": True, "url": url_for("main.index")}) # Form submitted to reset password elif request.form.get("form-type") == "reset_password": @@ -241,7 +245,7 @@ def reset_request() -> Response: elif password != password_confirmation: errors["password_confirmation"] = "Passwords do not match" if errors: - return jsonify({"errors": errors}) + return jsonify({"success": False, "errors": errors}) # Successful reset else: @@ -252,7 +256,7 @@ def reset_request() -> Response: # Redirect to login page flash("Success! Your password has been reset") - return jsonify({"url": url_for("main.index")}) + return jsonify({"success": True, "url": url_for("main.index")}) # Request method is GET else: diff --git a/migrations/env.py b/migrations/env.py index 4c97092..0749ebf 100644 --- a/migrations/env.py +++ b/migrations/env.py @@ -1,9 +1,8 @@ import logging from logging.config import fileConfig -from flask import current_app - from alembic import context +from flask import current_app # this is the Alembic Config object, which provides # access to the values within the .ini file in use. diff --git a/migrations/versions/afe16cfee729_.py b/migrations/versions/afe16cfee729_.py index 5994211..47da4ae 100644 --- a/migrations/versions/afe16cfee729_.py +++ b/migrations/versions/afe16cfee729_.py @@ -5,9 +5,8 @@ Create Date: 2024-02-24 13:57:36.774279 """ -from alembic import op import sqlalchemy as sa - +from alembic import op # revision identifiers, used by Alembic. revision = 'afe16cfee729' diff --git a/tests/conftest.py b/tests/conftest.py index 325ac60..344db66 100644 --- a/tests/conftest.py +++ b/tests/conftest.py @@ -1,19 +1,62 @@ +from datetime import date from typing import Any, Generator import pytest from flask import Flask from flask.testing import FlaskClient +from werkzeug.security import generate_password_hash -from app import create_app +from app import create_app, db from app.config import TestConfig +from app.models import Gender, User, UserRole -@pytest.fixture() +@pytest.fixture(scope='module') def app() -> Generator[Flask, Any, None]: + app = create_app(config=TestConfig) - yield app + + with app.app_context(): + db.create_all() + + yield app + + db.session.remove() + db.drop_all() + return -@pytest.fixture() +@pytest.fixture(scope='module') def client(app: Flask) -> FlaskClient: return app.test_client() + + +@pytest.fixture(scope='function') +def fake_user_password() -> str: + return 'ValidPassword1' + + +@pytest.fixture(scope='function') +def fake_user_client(fake_user_password: str) -> Generator[User, Any, None]: + + # Insert test data + fake_user_client = User( + email="client@example.com".lower(), + password_hash=generate_password_hash(fake_user_password), + first_name="John", + last_name="Smith", + date_joined=date.today(), + role=UserRole.CLIENT, + verified=True, + active=True, + gender=Gender.MALE, + ) + db.session.add(fake_user_client) + db.session.commit() + + yield fake_user_client + + db.session.delete(fake_user_client) + db.session.commit() + return + diff --git a/tests/test_auth.py b/tests/test_auth.py deleted file mode 100644 index 710b247..0000000 --- a/tests/test_auth.py +++ /dev/null @@ -1,7 +0,0 @@ -from flask.testing import FlaskClient - - -def test_get_login(client: FlaskClient): - get_response = client.get("/login") - assert get_response.status_code == 200 - diff --git a/tests/test_login.py b/tests/test_login.py new file mode 100644 index 0000000..4974632 --- /dev/null +++ b/tests/test_login.py @@ -0,0 +1,73 @@ +from unittest.mock import Mock, patch + +from flask.testing import FlaskClient + +from app import db +from app.models import User + + +@patch('app.views.auth.login_user') +def test_get_login(mock_login_user: Mock, client: FlaskClient): + response = client.get("/login") + assert response.status_code == 200 + mock_login_user.assert_not_called() + return + + +@patch('app.views.auth.login_user') +def test_user_login_success(mock_login_user: Mock, client: FlaskClient, fake_user_client: User, fake_user_password: str): + response = client.post("/login", data={ + "email": fake_user_client.email, + "password": fake_user_password, + }) + data = response.get_json() + assert response.status_code == 200 + assert data["success"] is True + assert "url" in data and data["url"] == "/index" + mock_login_user.assert_called_once() + return + + +@patch('app.views.auth.login_user') +def test_user_login_missing_credentials(mock_login_user: Mock, client: FlaskClient): + response = client.post("/login", data={}) + data = response.get_json() + assert response.status_code == 200 + assert data["success"] is False + assert "errors" in data + assert "email" in data["errors"] and "password" in data["errors"] + mock_login_user.assert_not_called() + return + + +@patch('app.views.auth.login_user') +def test_user_login_wrong_credentials(mock_login_user: Mock, client: FlaskClient, fake_user_client: User): + response = client.post("/login", data={ + "email": fake_user_client.email, + "password": "wrongpassword", + }) + data = response.get_json() + assert response.status_code == 200 + assert data["success"] is False + assert "errors" in data + assert "password" in data["errors"] + mock_login_user.assert_not_called() + return + + +@patch('app.views.auth.login_user') +def test_user_login_unverified(mock_login_user: Mock, client: FlaskClient, fake_user_client: User, fake_user_password: str): + fake_user_client.verified = False + db.session.commit() + response = client.post("/login", data={ + "email": fake_user_client.email, + "password": fake_user_password + }) + data = response.get_json() + assert response.status_code == 200 + assert data["success"] is True + assert "url" in data and data["url"] == "/verify-email" + mock_login_user.assert_not_called() + fake_user_client.verified = True + db.session.commit() + return diff --git a/tests/test_main.py b/tests/test_main.py index fb55eb5..d6e67a1 100644 --- a/tests/test_main.py +++ b/tests/test_main.py @@ -2,10 +2,10 @@ def test_get_index(client: FlaskClient): - get_response = client.get("/") - assert get_response.status_code == 200 + response = client.get("/") + assert response.status_code == 200 def test_post_index(client: FlaskClient): - post_response = client.post("/") - assert post_response.status_code == 405 + response = client.post("/") + assert response.status_code == 405 diff --git a/tests/test_register.py b/tests/test_register.py new file mode 100644 index 0000000..4a15859 --- /dev/null +++ b/tests/test_register.py @@ -0,0 +1,130 @@ +from unittest.mock import Mock, patch +from flask_mail import Mail + +from flask.testing import FlaskClient +import pytest + +from app import db +from app.models import User + + +@pytest.fixture(scope='function') +def new_user_data(fake_user_client: User, fake_user_password: str) -> dict: + return { + "role": fake_user_client.role.value, + "first_name": fake_user_client.first_name, + "last_name": fake_user_client.last_name, + "email": "different-" + fake_user_client.email, + "password": fake_user_password, + } + + +def test_get_register(client: FlaskClient): + response = client.get("/register") + assert response.status_code == 200 + return + + +@patch.object(Mail, 'send') +def test_register_success(mock_send_email: Mock, client: FlaskClient, new_user_data: dict): + + response = client.post("/register", data=new_user_data) + data = response.get_json() + + assert response.status_code == 200 + assert data["success"] is True + assert "url" in data + assert db.session.execute(db.select(User).filter_by(email=new_user_data["email"].lower())).scalar_one_or_none() is not None + mock_send_email.assert_called_once() + + return + + +@patch.object(Mail, 'send') +def test_register_missing_fields(mock_send_email: Mock, client: FlaskClient): + + initial_user_count = db.session.execute(db.select(db.func.count()).select_from(User)).scalar() + + response = client.post("/register", data={}) + data = response.get_json() + + assert response.status_code == 200 + assert data["success"] is False + assert "errors" in data + assert db.session.execute(db.select(db.func.count()).select_from(User)).scalar() == initial_user_count + mock_send_email.assert_not_called() + + return + + +@patch.object(Mail, 'send') +def test_register_invalid_role(mock_send_email: Mock, client: FlaskClient, new_user_data: dict): + + invalid_user_data = new_user_data.copy() + invalid_user_data["role"] = "invalid_role" + initial_user_count = db.session.execute(db.select(db.func.count()).select_from(User)).scalar() + + response = client.post("/register", data=invalid_user_data) + data = response.get_json() + + assert response.status_code == 200 + assert "errors" in data + assert "role" in data["errors"] + assert db.session.execute(db.select(db.func.count()).select_from(User)).scalar() == initial_user_count + mock_send_email.assert_not_called() + + return + + +@patch.object(Mail, 'send') +def test_register_invalid_email(mock_send_email: Mock, client: FlaskClient, new_user_data: dict): + + invalid_user_data = new_user_data.copy() + invalid_user_data["email"] = "invalidemail" + initial_user_count = db.session.execute(db.select(db.func.count()).select_from(User)).scalar() + + response = client.post("/register", data=invalid_user_data) + data = response.get_json() + + assert response.status_code == 200 + assert "errors" in data + assert "email" in data["errors"] + assert db.session.execute(db.select(db.func.count()).select_from(User)).scalar() == initial_user_count + mock_send_email.assert_not_called() + + return + + +@patch.object(Mail, 'send') +def test_register_duplicate_email(mock_send_email: Mock, client: FlaskClient, new_user_data: dict): + + initial_user_count = db.session.execute(db.select(db.func.count()).select_from(User)).scalar() + + response = client.post("/register", data=new_user_data) + data = response.get_json() + + assert response.status_code == 200 + assert "errors" in data + assert "email" in data["errors"] + assert db.session.execute(db.select(db.func.count()).select_from(User)).scalar() == initial_user_count + mock_send_email.assert_not_called() + + return + +@patch.object(Mail, 'send') +def test_register_weak_password(mock_send_email: Mock, client: FlaskClient, new_user_data: dict): + + invalid_user_data = new_user_data.copy() + invalid_user_data["password"] = "123" + initial_user_count = db.session.execute(db.select(db.func.count()).select_from(User)).scalar() + + response = client.post("/register", data=invalid_user_data) + data = response.get_json() + + assert response.status_code == 200 + assert "errors" in data + assert "password" in data["errors"] + assert db.session.execute(db.select(db.func.count()).select_from(User)).scalar() == initial_user_count + mock_send_email.assert_not_called() + + return From cc5d2b3541f0f1e0c1eb0f452e1a56217c10bbba Mon Sep 17 00:00:00 2001 From: Neil Shaabi <66903165+neilshaabi@users.noreply.github.com> Date: Sun, 25 Feb 2024 19:54:51 +0000 Subject: [PATCH 4/6] Added linting with flake8 --- .flake8 | 2 + Makefile | 23 ++++-- app/__init__.py | 2 + app/config.py | 4 +- app/models.py | 22 +++--- app/utils/validators.py | 1 + app/views/auth.py | 153 +++++++++++++++++++--------------------- requirements.txt | 12 +++- tests/test_register.py | 4 +- 9 files changed, 120 insertions(+), 103 deletions(-) create mode 100644 .flake8 diff --git a/.flake8 b/.flake8 new file mode 100644 index 0000000..16520fc --- /dev/null +++ b/.flake8 @@ -0,0 +1,2 @@ +[flake8] +ignore = E501 \ No newline at end of file diff --git a/Makefile b/Makefile index 686ccd0..9fbf6e0 100644 --- a/Makefile +++ b/Makefile @@ -1,21 +1,32 @@ venv: - @echo "\nCreating virtual environment..." + @echo "Creating virtual environment..." python3 -m venv .venv - @echo "\nNote: You will need to activate the virtual environment in your shell manually using:" + @echo "Note: You will need to activate the virtual environment in your shell manually using:" @echo "source .venv/bin/activate" deps: - @echo "\nInstalling dependencies..." + @echo "Installing dependencies..." pip install -r requirements.txt app: - @echo "\nRunning Flask app locally..." + @echo "Running Flask app locally..." flask run clean: - @echo "\nCleaning up directory..." + @echo "Cleaning up directory..." rm -rf .venv find . -type d -name '__pycache__' -exec rm -r {} + find . -type f -name '*.pyc' -delete -.PHONY: venv deps app clean \ No newline at end of file +lint: + @echo "Linting Python files with flake8..." + flake8 --exclude .venv,./migrations + +test: + @echo "Running tests with pytest..." + pytest -s + +help: + @echo "Available commands: make [help, venv, deps, app, test, clean]" + +.PHONY: help venv deps app lint test clean \ No newline at end of file diff --git a/app/__init__.py b/app/__init__.py index d995fc3..ccd430e 100644 --- a/app/__init__.py +++ b/app/__init__.py @@ -22,7 +22,9 @@ selected_config = CONFIGS[os.environ["ENV"]] + def create_app(config: Config = selected_config): + app = Flask(__name__) app.config.from_object(config) diff --git a/app/config.py b/app/config.py index a653733..5a432f1 100644 --- a/app/config.py +++ b/app/config.py @@ -24,7 +24,9 @@ class Config(object): class DevConfig(Config): DEBUG: bool = True - SQLALCHEMY_DATABASE_URI: str = "sqlite:///" + os.path.join(basedir, "mindli.sqlite") + SQLALCHEMY_DATABASE_URI: str = ( + "sqlite:///" + os.path.join(basedir, "mindli.sqlite") + ) class ProdConfig(Config): diff --git a/app/models.py b/app/models.py index 8f766f4..afc7815 100644 --- a/app/models.py +++ b/app/models.py @@ -12,7 +12,9 @@ @login_manager.user_loader def load_user(user_id: str): - return db.session.execute(db.select(User).filter_by(id=int(user_id))).scalar_one() + return db.session.execute( + db.select(User).filter_by(id=int(user_id)) + ).scalar_one() @unique @@ -85,7 +87,7 @@ class User(UserMixin, db.Model): photo_url: so.Mapped[Optional[str]] = so.mapped_column(sa.String(255)) timezone: so.Mapped[Optional[str]] = so.mapped_column(sa.String(50)) # IANA Time Zone Database name currency: so.Mapped[Optional[str]] = so.mapped_column(sa.String(3)) # ISO 4217 currency code - + client: so.Mapped[Optional["Client"]] = so.relationship(back_populates="user") therapist: so.Mapped[Optional["Therapist"]] = so.relationship(back_populates="user") @@ -95,7 +97,7 @@ class Client(db.Model): user_id: so.Mapped[int] = so.mapped_column(sa.ForeignKey('user.id'), index=True) preferred_gender: so.Mapped[Optional["Gender"]] = so.mapped_column(sa.Enum(Gender)) preferred_language_id: so.Mapped[Optional[int]] = so.mapped_column(sa.ForeignKey('language.id')) - + user: so.Mapped["User"] = so.relationship(back_populates="client") issues: so.Mapped[List["Issue"]] = so.relationship(secondary=client_issue, back_populates="clients") preferred_language: so.Mapped[Optional["Language"]] = so.relationship("Language") @@ -112,7 +114,7 @@ class Therapist(db.Model): registrations: so.Mapped[Optional[str]] = so.mapped_column(sa.Text) qualifications: so.Mapped[Optional[str]] = so.mapped_column(sa.Text) years_of_experience: so.Mapped[Optional[int]] = so.mapped_column(sa.Integer) - + user: so.Mapped["User"] = so.relationship(back_populates="therapist") languages: so.Mapped[List["Language"]] = so.relationship(secondary=therapist_language, back_populates="therapists") specialisations: so.Mapped[List["Issue"]] = so.relationship(secondary=therapist_issue, back_populates="therapists") @@ -127,14 +129,14 @@ class Language(db.Model): name: so.Mapped[str] = so.mapped_column(sa.String(50), unique=True) iso639_1: so.Mapped[Optional[str]] = so.mapped_column(sa.String(2), unique=True) # ISO 639-1 two-letter code iso639_2: so.Mapped[Optional[str]] = so.mapped_column(sa.String(3), unique=True) # ISO 639-2 three-letter code - + therapists: so.Mapped[List["Therapist"]] = so.relationship(secondary=therapist_language, back_populates="languages") class Issue(db.Model): id: so.Mapped[int] = so.mapped_column(primary_key=True) name: so.Mapped[str] = so.mapped_column(sa.String(50), unique=True) - + clients: so.Mapped[List["Client"]] = so.relationship(secondary=client_issue, back_populates="issues") therapists: so.Mapped[List["Therapist"]] = so.relationship(secondary=therapist_issue, back_populates="specialisations") @@ -142,7 +144,7 @@ class Issue(db.Model): class Intervention(db.Model): id: so.Mapped[int] = so.mapped_column(primary_key=True) name: so.Mapped[str] = so.mapped_column(sa.String(50), unique=True) - + therapists: so.Mapped[List["Therapist"]] = so.relationship(secondary=therapist_intervention, back_populates="interventions") @@ -155,7 +157,7 @@ class SessionType(db.Model): fee_currency: so.Mapped[str] = so.mapped_column(sa.String(3)) session_format: so.Mapped[Optional["SessionFormat"]] = so.mapped_column(sa.Enum(SessionFormat)) notes: so.Mapped[Optional[str]] = so.mapped_column(sa.Text) - + therapist: so.Mapped["Therapist"] = so.relationship(back_populates="session_types") @@ -166,7 +168,7 @@ class Availability(db.Model): start_time: so.Mapped[Optional[time]] = so.mapped_column(sa.Time) end_time: so.Mapped[Optional[time]] = so.mapped_column(sa.Time) specific_date: so.Mapped[Optional[date]] = so.mapped_column(sa.Date) # For non-recurring availability - + therapist: so.Mapped["Therapist"] = so.relationship(back_populates="availabilities") @@ -176,7 +178,7 @@ class Unavailability(db.Model): start_date: so.Mapped[date] = so.mapped_column(sa.Date) end_date: so.Mapped[date] = so.mapped_column(sa.Date) reason: so.Mapped[Optional[str]] = so.mapped_column(sa.Text) - + therapist: so.Mapped["Therapist"] = so.relationship(back_populates="unavailabilities") diff --git a/app/utils/validators.py b/app/utils/validators.py index 43ec233..9307bc9 100644 --- a/app/utils/validators.py +++ b/app/utils/validators.py @@ -1,5 +1,6 @@ import re + def isValidText(text: str) -> bool: return text and not text.isspace() diff --git a/app/views/auth.py b/app/views/auth.py index 45ff3b5..699825f 100644 --- a/app/views/auth.py +++ b/app/views/auth.py @@ -30,8 +30,11 @@ def logout() -> Response: @bp.route("/register", methods=["GET", "POST"]) def register() -> Response: - if request.method == "POST": - + if request.method == "GET": + logout_user() + return render_template("register.html") + + else: errors = {} # Get form data @@ -40,7 +43,6 @@ def register() -> Response: last_name = request.form.get("last_name") email = request.form.get("email") password = request.form.get("password") - # Validate input if not role or (role not in set(r.value for r in UserRole)): @@ -58,48 +60,44 @@ def register() -> Response: if errors: return jsonify({"success": False, "errors": errors}) - # Proceed with successful registration - else: - # Insert new user into database - email = email.lower() - user = User( - email=email.lower(), - password_hash=generate_password_hash(password), - first_name=first_name.capitalize(), - last_name=last_name.capitalize(), - date_joined=date.today(), - role=UserRole(role), - verified=False, - active=True, - ) - db.session.add(user) - db.session.commit() - - # Send verification email - email_message = EmailMessage( - mail=mail, - subject=EmailSubject.EMAIL_VERIFICATION, - recipient=user, - serialiser=current_app.serialiser, - ) - email_message.send() - - # Store email in session for email verification - session["email"] = email - return jsonify({"success": True, "url": url_for("auth.verify_email")}) + # Insert new user into database + email = email.lower() + user = User( + email=email.lower(), + password_hash=generate_password_hash(password), + first_name=first_name.capitalize(), + last_name=last_name.capitalize(), + date_joined=date.today(), + role=UserRole(role), + verified=False, + active=True, + ) + db.session.add(user) + db.session.commit() - # Request method is GET - else: - logout_user() - return render_template("register.html") + # Send verification email + email_message = EmailMessage( + mail=mail, + subject=EmailSubject.EMAIL_VERIFICATION, + recipient=user, + serialiser=current_app.serialiser, + ) + email_message.send() + + # Store email in session for email verification + session["email"] = email + return jsonify({"success": True, "url": url_for("auth.verify_email")}) # Logs user in if credentials are valid @bp.route("/login", methods=["GET", "POST"]) def login() -> Response: - if request.method == "POST": - + if request.method == "GET": + logout_user() + return render_template("login.html") + + elif request.method == "POST": errors = {} # Get form data @@ -128,11 +126,6 @@ def login() -> Response: login_user(user) return jsonify({"success": True, "url": url_for("main.index")}) - # Request method is GET - else: - logout_user() - return render_template("login.html") - # Displays page with email verification instructions, sends verification email @bp.route("/verify-email", methods=["GET", "POST"]) @@ -148,8 +141,11 @@ def verify_email() -> Response: if not user or user.verified: return redirect(url_for("main.index")) - # Sends verification email to user (POST used to utilise AJAX) - if request.method == "POST": + if request.method == "GET": + return render_template("verify-email.html", email=session["email"]) + + # Send verification email to user + elif request.method == "POST": email_message = EmailMessage( mail=mail, subject=EmailSubject.EMAIL_VERIFICATION, @@ -159,13 +155,11 @@ def verify_email() -> Response: email_message.send() flash(f"Email verification instructions sent to {user.email}") return jsonify({"success": True, "url": url_for("main.index")}) - - else: - return render_template("verify-email.html", email=session["email"]) + # Handles email verification using token -@bp.route("/email-verification/") +@bp.route("/email-verification/", methods=["GET"]) def email_verification(token): # Get email from token @@ -196,8 +190,11 @@ def email_verification(token): # Handles password resets by sending emails and updating the database @bp.route("/reset-password", methods=["GET", "POST"]) def reset_request() -> Response: - if request.method == "POST": - + + if request.method == "GET": + return render_template("initiate-password-reset.html") + + elif request.method == "POST": errors = {} # Form submitted to initiate a password reset @@ -209,25 +206,23 @@ def reset_request() -> Response: # Find user with this email user = db.session.execute(db.select(User).filter_by(email=email)).scalar_one_or_none() - # Check if user with this email does not exist + # Validate input if user is None: errors["email"] = "No account found with this email address" - - # Return errors if any if errors: return jsonify({"success": False, "errors": errors}) - # Send reset email - else: - email_message = EmailMessage( - mail=mail, - subject=EmailSubject.PASSWORD_RESET, - recipient=user, - serialiser=current_app.serialiser, - ) - email_message.send() - flash(f"Password reset instructions sent to {email}") - return jsonify({"success": True, "url": url_for("main.index")}) + # Send email with instructions + email_message = EmailMessage( + mail=mail, + subject=EmailSubject.PASSWORD_RESET, + recipient=user, + serialiser=current_app.serialiser, + ) + email_message.send() + + flash(f"Password reset instructions sent to {email}") + return jsonify({"success": True, "url": url_for("main.index")}) # Form submitted to reset password elif request.form.get("form-type") == "reset_password": @@ -247,25 +242,20 @@ def reset_request() -> Response: if errors: return jsonify({"success": False, "errors": errors}) - # Successful reset - else: - # Update user's password in database - user = db.session.execute(db.select(User).filter_by(email=email)).scalar_one_or_none() - user.password_hash = generate_password_hash(password) - db.session.commit() - - # Redirect to login page - flash("Success! Your password has been reset") - return jsonify({"success": True, "url": url_for("main.index")}) - - # Request method is GET - else: - return render_template("initiate-password-reset.html") + # Update user's password in database + user = db.session.execute(db.select(User).filter_by(email=email)).scalar_one_or_none() + user.password_hash = generate_password_hash(password) + db.session.commit() + + # Redirect to login page + flash("Success! Your password has been reset") + return jsonify({"success": True, "url": url_for("main.index")}) # Displays page to update password -@bp.route("/reset-password/") +@bp.route("/reset-password/", methods=["GET"]) def reset_password(token): + # Get email from token try: email = current_app.serialiser.loads( @@ -275,5 +265,6 @@ def reset_password(token): # Invalid/expired token except (BadSignature, SignatureExpired): - flash("Invalid or expired reset link, " "please request another password reset") + flash("Invalid or expired reset link, " + "please request another password reset") return redirect(url_for("main.index")) diff --git a/requirements.txt b/requirements.txt index 0c63cf0..e3ed2dd 100644 --- a/requirements.txt +++ b/requirements.txt @@ -1,15 +1,18 @@ alembic==1.12.1 blinker==1.6.3 click==8.1.7 +dnspython==2.3.0 exceptiongroup==1.2.0 +flake8==5.0.4 Flask==2.2.5 Flask-Login==0.6.3 Flask-Mail==0.9.1 Flask-Migrate==4.0.5 -flask-sqlalchemy==3.0.5 +Flask-SQLAlchemy==3.0.5 Flask-WTF==1.1.1 greenlet==3.0.3 -importlib-metadata==6.7.0 +idna==3.6 +importlib-metadata==4.2.0 importlib-resources==5.12.0 iniconfig==2.0.0 isort==5.11.5 @@ -17,18 +20,21 @@ itsdangerous==2.1.2 Jinja2==3.1.2 Mako==1.2.4 MarkupSafe==2.1.3 +mccabe==0.7.0 mypy-extensions==1.0.0 packaging==23.2 pathspec==0.11.2 platformdirs==4.0.0 pluggy==1.2.0 +pycodestyle==2.9.1 +pyflakes==2.5.0 pytest==7.4.4 pytest-dotenv==0.5.2 python-dotenv==0.21.1 SQLAlchemy==2.0.23 tomli==2.0.1 typed-ast==1.5.5 -typing-extensions==4.7.1 +typing_extensions==4.7.1 Werkzeug==2.2.3 WTForms==3.0.1 zipp==3.15.0 diff --git a/tests/test_register.py b/tests/test_register.py index 4a15859..bd8efc3 100644 --- a/tests/test_register.py +++ b/tests/test_register.py @@ -1,8 +1,8 @@ from unittest.mock import Mock, patch -from flask_mail import Mail -from flask.testing import FlaskClient import pytest +from flask.testing import FlaskClient +from flask_mail import Mail from app import db from app.models import User From ce7921601a562f0e1c33bbc96a5f835591f82882 Mon Sep 17 00:00:00 2001 From: Neil Shaabi <66903165+neilshaabi@users.noreply.github.com> Date: Sun, 25 Feb 2024 20:23:06 +0000 Subject: [PATCH 5/6] flake8 and black --- .flake8 | 6 ++- Makefile | 6 ++- app/__init__.py | 2 +- app/config.py | 4 +- app/models.py | 102 +++++++++++++++++++++++++++++----------- app/utils/mail.py | 2 - app/utils/validators.py | 6 ++- app/views/auth.py | 73 ++++++++++++++++------------ requirements.txt | 1 + tests/conftest.py | 43 ++++++++--------- tests/test_login.py | 57 ++++++++++++++-------- tests/test_register.py | 97 ++++++++++++++++++++++++++------------ 12 files changed, 258 insertions(+), 141 deletions(-) diff --git a/.flake8 b/.flake8 index 16520fc..3f9d140 100644 --- a/.flake8 +++ b/.flake8 @@ -1,2 +1,6 @@ [flake8] -ignore = E501 \ No newline at end of file +ignore = E501, W503 +max-line-length = 88 +exclude = + .venv, + migrations \ No newline at end of file diff --git a/Makefile b/Makefile index 9fbf6e0..3994993 100644 --- a/Makefile +++ b/Makefile @@ -19,7 +19,11 @@ clean: find . -type f -name '*.pyc' -delete lint: - @echo "Linting Python files with flake8..." + @echo "Reorganising imports..." + isort . + @echo "Formatting Python files..." + black . --exclude '/(\.venv|migrations)/' + @echo "Linting Python files..." flake8 --exclude .venv,./migrations test: diff --git a/app/__init__.py b/app/__init__.py index ccd430e..a10ac49 100644 --- a/app/__init__.py +++ b/app/__init__.py @@ -24,7 +24,6 @@ def create_app(config: Config = selected_config): - app = Flask(__name__) app.config.from_object(config) @@ -37,6 +36,7 @@ def create_app(config: Config = selected_config): # Register blueprints from app.views import auth, main + app.register_blueprint(main.bp) app.register_blueprint(auth.bp) diff --git a/app/config.py b/app/config.py index 5a432f1..a653733 100644 --- a/app/config.py +++ b/app/config.py @@ -24,9 +24,7 @@ class Config(object): class DevConfig(Config): DEBUG: bool = True - SQLALCHEMY_DATABASE_URI: str = ( - "sqlite:///" + os.path.join(basedir, "mindli.sqlite") - ) + SQLALCHEMY_DATABASE_URI: str = "sqlite:///" + os.path.join(basedir, "mindli.sqlite") class ProdConfig(Config): diff --git a/app/models.py b/app/models.py index afc7815..f03649a 100644 --- a/app/models.py +++ b/app/models.py @@ -12,9 +12,7 @@ @login_manager.user_loader def load_user(user_id: str): - return db.session.execute( - db.select(User).filter_by(id=int(user_id)) - ).scalar_one() + return db.session.execute(db.select(User).filter_by(id=int(user_id))).scalar_one() @unique @@ -85,8 +83,12 @@ class User(UserMixin, db.Model): active: so.Mapped[bool] = so.mapped_column(sa.Boolean, default=True) gender: so.Mapped[Optional["Gender"]] = so.mapped_column(sa.Enum(Gender)) photo_url: so.Mapped[Optional[str]] = so.mapped_column(sa.String(255)) - timezone: so.Mapped[Optional[str]] = so.mapped_column(sa.String(50)) # IANA Time Zone Database name - currency: so.Mapped[Optional[str]] = so.mapped_column(sa.String(3)) # ISO 4217 currency code + timezone: so.Mapped[Optional[str]] = so.mapped_column( + sa.String(50) + ) # IANA Time Zone Database name + currency: so.Mapped[Optional[str]] = so.mapped_column( + sa.String(3) + ) # ISO 4217 currency code client: so.Mapped[Optional["Client"]] = so.relationship(back_populates="user") therapist: so.Mapped[Optional["Therapist"]] = so.relationship(back_populates="user") @@ -94,12 +96,16 @@ class User(UserMixin, db.Model): class Client(db.Model): id: so.Mapped[int] = so.mapped_column(primary_key=True) - user_id: so.Mapped[int] = so.mapped_column(sa.ForeignKey('user.id'), index=True) + user_id: so.Mapped[int] = so.mapped_column(sa.ForeignKey("user.id"), index=True) preferred_gender: so.Mapped[Optional["Gender"]] = so.mapped_column(sa.Enum(Gender)) - preferred_language_id: so.Mapped[Optional[int]] = so.mapped_column(sa.ForeignKey('language.id')) + preferred_language_id: so.Mapped[Optional[int]] = so.mapped_column( + sa.ForeignKey("language.id") + ) user: so.Mapped["User"] = so.relationship(back_populates="client") - issues: so.Mapped[List["Issue"]] = so.relationship(secondary=client_issue, back_populates="clients") + issues: so.Mapped[List["Issue"]] = so.relationship( + secondary=client_issue, back_populates="clients" + ) preferred_language: so.Mapped[Optional["Language"]] = so.relationship("Language") @@ -116,46 +122,76 @@ class Therapist(db.Model): years_of_experience: so.Mapped[Optional[int]] = so.mapped_column(sa.Integer) user: so.Mapped["User"] = so.relationship(back_populates="therapist") - languages: so.Mapped[List["Language"]] = so.relationship(secondary=therapist_language, back_populates="therapists") - specialisations: so.Mapped[List["Issue"]] = so.relationship(secondary=therapist_issue, back_populates="therapists") - interventions: so.Mapped[List["Intervention"]] = so.relationship(secondary=therapist_intervention, back_populates="therapists") - session_types: so.Mapped[List["SessionType"]] = so.relationship(back_populates="therapist") - availabilities: so.Mapped[List["Availability"]] = so.relationship(back_populates="therapist") - unavailabilities: so.Mapped[List["Unavailability"]] = so.relationship(back_populates="therapist") + languages: so.Mapped[List["Language"]] = so.relationship( + secondary=therapist_language, back_populates="therapists" + ) + specialisations: so.Mapped[List["Issue"]] = so.relationship( + secondary=therapist_issue, back_populates="therapists" + ) + interventions: so.Mapped[List["Intervention"]] = so.relationship( + secondary=therapist_intervention, back_populates="therapists" + ) + session_types: so.Mapped[List["SessionType"]] = so.relationship( + back_populates="therapist" + ) + availabilities: so.Mapped[List["Availability"]] = so.relationship( + back_populates="therapist" + ) + unavailabilities: so.Mapped[List["Unavailability"]] = so.relationship( + back_populates="therapist" + ) class Language(db.Model): id: so.Mapped[int] = so.mapped_column(primary_key=True) name: so.Mapped[str] = so.mapped_column(sa.String(50), unique=True) - iso639_1: so.Mapped[Optional[str]] = so.mapped_column(sa.String(2), unique=True) # ISO 639-1 two-letter code - iso639_2: so.Mapped[Optional[str]] = so.mapped_column(sa.String(3), unique=True) # ISO 639-2 three-letter code + iso639_1: so.Mapped[Optional[str]] = so.mapped_column( + sa.String(2), unique=True + ) # ISO 639-1 two-letter code + iso639_2: so.Mapped[Optional[str]] = so.mapped_column( + sa.String(3), unique=True + ) # ISO 639-2 three-letter code - therapists: so.Mapped[List["Therapist"]] = so.relationship(secondary=therapist_language, back_populates="languages") + therapists: so.Mapped[List["Therapist"]] = so.relationship( + secondary=therapist_language, back_populates="languages" + ) class Issue(db.Model): id: so.Mapped[int] = so.mapped_column(primary_key=True) name: so.Mapped[str] = so.mapped_column(sa.String(50), unique=True) - clients: so.Mapped[List["Client"]] = so.relationship(secondary=client_issue, back_populates="issues") - therapists: so.Mapped[List["Therapist"]] = so.relationship(secondary=therapist_issue, back_populates="specialisations") + clients: so.Mapped[List["Client"]] = so.relationship( + secondary=client_issue, back_populates="issues" + ) + therapists: so.Mapped[List["Therapist"]] = so.relationship( + secondary=therapist_issue, back_populates="specialisations" + ) class Intervention(db.Model): id: so.Mapped[int] = so.mapped_column(primary_key=True) name: so.Mapped[str] = so.mapped_column(sa.String(50), unique=True) - therapists: so.Mapped[List["Therapist"]] = so.relationship(secondary=therapist_intervention, back_populates="interventions") + therapists: so.Mapped[List["Therapist"]] = so.relationship( + secondary=therapist_intervention, back_populates="interventions" + ) class SessionType(db.Model): id: so.Mapped[int] = so.mapped_column(primary_key=True) - therapist_id: so.Mapped[int] = so.mapped_column(sa.ForeignKey("therapist.id"), index=True) - name: so.Mapped[str] = so.mapped_column(sa.String(255)) # e.g. "Initial Consultation" + therapist_id: so.Mapped[int] = so.mapped_column( + sa.ForeignKey("therapist.id"), index=True + ) + name: so.Mapped[str] = so.mapped_column( + sa.String(255) + ) # e.g. "Initial Consultation" session_duration: so.Mapped[int] = so.mapped_column(sa.Integer) # In minutes fee_amount: so.Mapped[float] = so.mapped_column(sa.Float) fee_currency: so.Mapped[str] = so.mapped_column(sa.String(3)) - session_format: so.Mapped[Optional["SessionFormat"]] = so.mapped_column(sa.Enum(SessionFormat)) + session_format: so.Mapped[Optional["SessionFormat"]] = so.mapped_column( + sa.Enum(SessionFormat) + ) notes: so.Mapped[Optional[str]] = so.mapped_column(sa.Text) therapist: so.Mapped["Therapist"] = so.relationship(back_populates="session_types") @@ -163,23 +199,33 @@ class SessionType(db.Model): class Availability(db.Model): id: so.Mapped[int] = so.mapped_column(primary_key=True) - therapist_id: so.Mapped[int] = so.mapped_column(sa.ForeignKey("therapist.id"), index=True) - day_of_week: so.Mapped[Optional[int]] = so.mapped_column(sa.Integer) # 0=Monday, 6=Sunday, None for specific dates + therapist_id: so.Mapped[int] = so.mapped_column( + sa.ForeignKey("therapist.id"), index=True + ) + day_of_week: so.Mapped[Optional[int]] = so.mapped_column( + sa.Integer + ) # 0=Monday, 6=Sunday, None for specific dates start_time: so.Mapped[Optional[time]] = so.mapped_column(sa.Time) end_time: so.Mapped[Optional[time]] = so.mapped_column(sa.Time) - specific_date: so.Mapped[Optional[date]] = so.mapped_column(sa.Date) # For non-recurring availability + specific_date: so.Mapped[Optional[date]] = so.mapped_column( + sa.Date + ) # For non-recurring availability therapist: so.Mapped["Therapist"] = so.relationship(back_populates="availabilities") class Unavailability(db.Model): id: so.Mapped[int] = so.mapped_column(primary_key=True) - therapist_id: so.Mapped[int] = so.mapped_column(sa.ForeignKey("therapist.id"), index=True) + therapist_id: so.Mapped[int] = so.mapped_column( + sa.ForeignKey("therapist.id"), index=True + ) start_date: so.Mapped[date] = so.mapped_column(sa.Date) end_date: so.Mapped[date] = so.mapped_column(sa.Date) reason: so.Mapped[Optional[str]] = so.mapped_column(sa.Text) - therapist: so.Mapped["Therapist"] = so.relationship(back_populates="unavailabilities") + therapist: so.Mapped["Therapist"] = so.relationship( + back_populates="unavailabilities" + ) def insertDummyData() -> None: diff --git a/app/utils/mail.py b/app/utils/mail.py index d0ba379..acfec3a 100644 --- a/app/utils/mail.py +++ b/app/utils/mail.py @@ -15,7 +15,6 @@ class EmailSubject(Enum): class EmailMessage: - def __init__( self, mail: Mail, @@ -23,7 +22,6 @@ def __init__( recipient: User, serialiser: Optional[URLSafeTimedSerializer], ) -> None: - self.mail = mail self.recipient = recipient self.subject = subject.value diff --git a/app/utils/validators.py b/app/utils/validators.py index 9307bc9..dbe8b9e 100644 --- a/app/utils/validators.py +++ b/app/utils/validators.py @@ -4,10 +4,12 @@ def isValidText(text: str) -> bool: return text and not text.isspace() + def isValidEmail(email: str) -> bool: - email_regex = r'^[a-zA-Z0-9._%+-]+@[a-zA-Z0-9.-]+\.[a-zA-Z]{2,}$' + email_regex = r"^[a-zA-Z0-9._%+-]+@[a-zA-Z0-9.-]+\.[a-zA-Z]{2,}$" return email and re.match(email_regex, email) + def isValidPassword(password: str) -> bool: return ( password @@ -15,4 +17,4 @@ def isValidPassword(password: str) -> bool: and any(char.isdigit() for char in password) and any(char.isupper() for char in password) and any(char.islower() for char in password) - ) \ No newline at end of file + ) diff --git a/app/views/auth.py b/app/views/auth.py index 699825f..1b268d4 100644 --- a/app/views/auth.py +++ b/app/views/auth.py @@ -1,7 +1,17 @@ from datetime import date -from flask import (Blueprint, Response, current_app, flash, jsonify, redirect, - render_template, request, session, url_for) +from flask import ( + Blueprint, + Response, + current_app, + flash, + jsonify, + redirect, + render_template, + request, + session, + url_for, +) from flask_login import login_user, logout_user from itsdangerous import BadSignature, SignatureExpired from werkzeug.security import check_password_hash, generate_password_hash @@ -29,7 +39,6 @@ def logout() -> Response: @bp.route("/register", methods=["GET", "POST"]) def register() -> Response: - if request.method == "GET": logout_user() return render_template("register.html") @@ -43,7 +52,7 @@ def register() -> Response: last_name = request.form.get("last_name") email = request.form.get("email") password = request.form.get("password") - + # Validate input if not role or (role not in set(r.value for r in UserRole)): errors["role"] = "Account type is required" @@ -55,7 +64,12 @@ def register() -> Response: errors["password"] = "Password does not meet requirements" if not isValidEmail(email): errors["email"] = "Invalid email address" - elif db.session.execute(db.select(User).filter_by(email=email.lower())).scalar_one_or_none() is not None: + elif ( + db.session.execute( + db.select(User).filter_by(email=email.lower()) + ).scalar_one_or_none() + is not None + ): errors["email"] = "Email address is already in use" if errors: return jsonify({"success": False, "errors": errors}) @@ -83,7 +97,7 @@ def register() -> Response: serialiser=current_app.serialiser, ) email_message.send() - + # Store email in session for email verification session["email"] = email return jsonify({"success": True, "url": url_for("auth.verify_email")}) @@ -92,7 +106,6 @@ def register() -> Response: # Logs user in if credentials are valid @bp.route("/login", methods=["GET", "POST"]) def login() -> Response: - if request.method == "GET": logout_user() return render_template("login.html") @@ -110,7 +123,9 @@ def login() -> Response: if not password: errors["password"] = "Password is required" else: - user = db.session.execute(db.select(User).filter_by(email=email.lower())).scalar_one_or_none() + user = db.session.execute( + db.select(User).filter_by(email=email.lower()) + ).scalar_one_or_none() if user is None or not check_password_hash(user.password_hash, password): errors["password"] = "Incorrect email/password" if errors: @@ -130,10 +145,11 @@ def login() -> Response: # Displays page with email verification instructions, sends verification email @bp.route("/verify-email", methods=["GET", "POST"]) def verify_email() -> Response: - # Get user with email stored in session if "email" in session: - user = db.session.execute(db.select(User).filter_by(email=session["email"])).scalar_one_or_none() + user = db.session.execute( + db.select(User).filter_by(email=session["email"]) + ).scalar_one_or_none() else: user = None @@ -143,7 +159,7 @@ def verify_email() -> Response: if request.method == "GET": return render_template("verify-email.html", email=session["email"]) - + # Send verification email to user elif request.method == "POST": email_message = EmailMessage( @@ -155,13 +171,11 @@ def verify_email() -> Response: email_message.send() flash(f"Email verification instructions sent to {user.email}") return jsonify({"success": True, "url": url_for("main.index")}) - # Handles email verification using token @bp.route("/email-verification/", methods=["GET"]) def email_verification(token): - # Get email from token try: email = current_app.serialiser.loads( @@ -169,7 +183,9 @@ def email_verification(token): ) # Each token is valid for 5 days # Mark user as verified - user = db.session.execute(db.select(User).filter_by(email=email)).scalar_one_or_none() + user = db.session.execute( + db.select(User).filter_by(email=email) + ).scalar_one_or_none() user.verified = True db.session.commit() @@ -183,35 +199,35 @@ def email_verification(token): "Invalid or expired verification link, " "please log in to request a new link" ) - + return redirect(url_for("main.index")) # Handles password resets by sending emails and updating the database @bp.route("/reset-password", methods=["GET", "POST"]) def reset_request() -> Response: - if request.method == "GET": return render_template("initiate-password-reset.html") - - elif request.method == "POST": + + elif request.method == "POST": errors = {} - + # Form submitted to initiate a password reset if request.form.get("form-type") == "initiate_password_reset": - # Get form data email = request.form.get("email").lower() # Find user with this email - user = db.session.execute(db.select(User).filter_by(email=email)).scalar_one_or_none() + user = db.session.execute( + db.select(User).filter_by(email=email) + ).scalar_one_or_none() # Validate input if user is None: errors["email"] = "No account found with this email address" if errors: return jsonify({"success": False, "errors": errors}) - + # Send email with instructions email_message = EmailMessage( mail=mail, @@ -220,13 +236,12 @@ def reset_request() -> Response: serialiser=current_app.serialiser, ) email_message.send() - + flash(f"Password reset instructions sent to {email}") return jsonify({"success": True, "url": url_for("main.index")}) # Form submitted to reset password elif request.form.get("form-type") == "reset_password": - # Get form data email = request.form.get("email") password = request.form.get("password") @@ -241,9 +256,11 @@ def reset_request() -> Response: errors["password_confirmation"] = "Passwords do not match" if errors: return jsonify({"success": False, "errors": errors}) - + # Update user's password in database - user = db.session.execute(db.select(User).filter_by(email=email)).scalar_one_or_none() + user = db.session.execute( + db.select(User).filter_by(email=email) + ).scalar_one_or_none() user.password_hash = generate_password_hash(password) db.session.commit() @@ -255,7 +272,6 @@ def reset_request() -> Response: # Displays page to update password @bp.route("/reset-password/", methods=["GET"]) def reset_password(token): - # Get email from token try: email = current_app.serialiser.loads( @@ -265,6 +281,5 @@ def reset_password(token): # Invalid/expired token except (BadSignature, SignatureExpired): - flash("Invalid or expired reset link, " - "please request another password reset") + flash("Invalid or expired reset link, " "please request another password reset") return redirect(url_for("main.index")) diff --git a/requirements.txt b/requirements.txt index e3ed2dd..0f3f4f3 100644 --- a/requirements.txt +++ b/requirements.txt @@ -1,4 +1,5 @@ alembic==1.12.1 +black==23.3.0 blinker==1.6.3 click==8.1.7 dnspython==2.3.0 diff --git a/tests/conftest.py b/tests/conftest.py index 344db66..44f982e 100644 --- a/tests/conftest.py +++ b/tests/conftest.py @@ -11,52 +11,49 @@ from app.models import Gender, User, UserRole -@pytest.fixture(scope='module') +@pytest.fixture(scope="module") def app() -> Generator[Flask, Any, None]: - app = create_app(config=TestConfig) - + with app.app_context(): db.create_all() - + yield app - + db.session.remove() db.drop_all() return -@pytest.fixture(scope='module') +@pytest.fixture(scope="module") def client(app: Flask) -> FlaskClient: return app.test_client() -@pytest.fixture(scope='function') +@pytest.fixture(scope="function") def fake_user_password() -> str: - return 'ValidPassword1' + return "ValidPassword1" -@pytest.fixture(scope='function') +@pytest.fixture(scope="function") def fake_user_client(fake_user_password: str) -> Generator[User, Any, None]: - # Insert test data fake_user_client = User( - email="client@example.com".lower(), - password_hash=generate_password_hash(fake_user_password), - first_name="John", - last_name="Smith", - date_joined=date.today(), - role=UserRole.CLIENT, - verified=True, - active=True, - gender=Gender.MALE, - ) + email="client@example.com".lower(), + password_hash=generate_password_hash(fake_user_password), + first_name="John", + last_name="Smith", + date_joined=date.today(), + role=UserRole.CLIENT, + verified=True, + active=True, + gender=Gender.MALE, + ) db.session.add(fake_user_client) db.session.commit() - + yield fake_user_client - + db.session.delete(fake_user_client) db.session.commit() return - diff --git a/tests/test_login.py b/tests/test_login.py index 4974632..d7e51cc 100644 --- a/tests/test_login.py +++ b/tests/test_login.py @@ -6,7 +6,7 @@ from app.models import User -@patch('app.views.auth.login_user') +@patch("app.views.auth.login_user") def test_get_login(mock_login_user: Mock, client: FlaskClient): response = client.get("/login") assert response.status_code == 200 @@ -14,12 +14,20 @@ def test_get_login(mock_login_user: Mock, client: FlaskClient): return -@patch('app.views.auth.login_user') -def test_user_login_success(mock_login_user: Mock, client: FlaskClient, fake_user_client: User, fake_user_password: str): - response = client.post("/login", data={ - "email": fake_user_client.email, - "password": fake_user_password, - }) +@patch("app.views.auth.login_user") +def test_user_login_success( + mock_login_user: Mock, + client: FlaskClient, + fake_user_client: User, + fake_user_password: str, +): + response = client.post( + "/login", + data={ + "email": fake_user_client.email, + "password": fake_user_password, + }, + ) data = response.get_json() assert response.status_code == 200 assert data["success"] is True @@ -28,7 +36,7 @@ def test_user_login_success(mock_login_user: Mock, client: FlaskClient, fake_use return -@patch('app.views.auth.login_user') +@patch("app.views.auth.login_user") def test_user_login_missing_credentials(mock_login_user: Mock, client: FlaskClient): response = client.post("/login", data={}) data = response.get_json() @@ -40,12 +48,17 @@ def test_user_login_missing_credentials(mock_login_user: Mock, client: FlaskClie return -@patch('app.views.auth.login_user') -def test_user_login_wrong_credentials(mock_login_user: Mock, client: FlaskClient, fake_user_client: User): - response = client.post("/login", data={ - "email": fake_user_client.email, - "password": "wrongpassword", - }) +@patch("app.views.auth.login_user") +def test_user_login_wrong_credentials( + mock_login_user: Mock, client: FlaskClient, fake_user_client: User +): + response = client.post( + "/login", + data={ + "email": fake_user_client.email, + "password": "wrongpassword", + }, + ) data = response.get_json() assert response.status_code == 200 assert data["success"] is False @@ -55,14 +68,18 @@ def test_user_login_wrong_credentials(mock_login_user: Mock, client: FlaskClient return -@patch('app.views.auth.login_user') -def test_user_login_unverified(mock_login_user: Mock, client: FlaskClient, fake_user_client: User, fake_user_password: str): +@patch("app.views.auth.login_user") +def test_user_login_unverified( + mock_login_user: Mock, + client: FlaskClient, + fake_user_client: User, + fake_user_password: str, +): fake_user_client.verified = False db.session.commit() - response = client.post("/login", data={ - "email": fake_user_client.email, - "password": fake_user_password - }) + response = client.post( + "/login", data={"email": fake_user_client.email, "password": fake_user_password} + ) data = response.get_json() assert response.status_code == 200 assert data["success"] is True diff --git a/tests/test_register.py b/tests/test_register.py index bd8efc3..90bd004 100644 --- a/tests/test_register.py +++ b/tests/test_register.py @@ -8,7 +8,7 @@ from app.models import User -@pytest.fixture(scope='function') +@pytest.fixture(scope="function") def new_user_data(fake_user_client: User, fake_user_password: str) -> dict: return { "role": fake_user_client.role.value, @@ -25,44 +25,57 @@ def test_get_register(client: FlaskClient): return -@patch.object(Mail, 'send') -def test_register_success(mock_send_email: Mock, client: FlaskClient, new_user_data: dict): - +@patch.object(Mail, "send") +def test_register_success( + mock_send_email: Mock, client: FlaskClient, new_user_data: dict +): response = client.post("/register", data=new_user_data) data = response.get_json() assert response.status_code == 200 assert data["success"] is True assert "url" in data - assert db.session.execute(db.select(User).filter_by(email=new_user_data["email"].lower())).scalar_one_or_none() is not None + assert ( + db.session.execute( + db.select(User).filter_by(email=new_user_data["email"].lower()) + ).scalar_one_or_none() + is not None + ) mock_send_email.assert_called_once() return -@patch.object(Mail, 'send') +@patch.object(Mail, "send") def test_register_missing_fields(mock_send_email: Mock, client: FlaskClient): + initial_user_count = db.session.execute( + db.select(db.func.count()).select_from(User) + ).scalar() - initial_user_count = db.session.execute(db.select(db.func.count()).select_from(User)).scalar() - response = client.post("/register", data={}) data = response.get_json() - + assert response.status_code == 200 assert data["success"] is False assert "errors" in data - assert db.session.execute(db.select(db.func.count()).select_from(User)).scalar() == initial_user_count + assert ( + db.session.execute(db.select(db.func.count()).select_from(User)).scalar() + == initial_user_count + ) mock_send_email.assert_not_called() - + return -@patch.object(Mail, 'send') -def test_register_invalid_role(mock_send_email: Mock, client: FlaskClient, new_user_data: dict): - +@patch.object(Mail, "send") +def test_register_invalid_role( + mock_send_email: Mock, client: FlaskClient, new_user_data: dict +): invalid_user_data = new_user_data.copy() invalid_user_data["role"] = "invalid_role" - initial_user_count = db.session.execute(db.select(db.func.count()).select_from(User)).scalar() + initial_user_count = db.session.execute( + db.select(db.func.count()).select_from(User) + ).scalar() response = client.post("/register", data=invalid_user_data) data = response.get_json() @@ -70,18 +83,24 @@ def test_register_invalid_role(mock_send_email: Mock, client: FlaskClient, new_u assert response.status_code == 200 assert "errors" in data assert "role" in data["errors"] - assert db.session.execute(db.select(db.func.count()).select_from(User)).scalar() == initial_user_count + assert ( + db.session.execute(db.select(db.func.count()).select_from(User)).scalar() + == initial_user_count + ) mock_send_email.assert_not_called() return -@patch.object(Mail, 'send') -def test_register_invalid_email(mock_send_email: Mock, client: FlaskClient, new_user_data: dict): - +@patch.object(Mail, "send") +def test_register_invalid_email( + mock_send_email: Mock, client: FlaskClient, new_user_data: dict +): invalid_user_data = new_user_data.copy() invalid_user_data["email"] = "invalidemail" - initial_user_count = db.session.execute(db.select(db.func.count()).select_from(User)).scalar() + initial_user_count = db.session.execute( + db.select(db.func.count()).select_from(User) + ).scalar() response = client.post("/register", data=invalid_user_data) data = response.get_json() @@ -89,16 +108,22 @@ def test_register_invalid_email(mock_send_email: Mock, client: FlaskClient, new_ assert response.status_code == 200 assert "errors" in data assert "email" in data["errors"] - assert db.session.execute(db.select(db.func.count()).select_from(User)).scalar() == initial_user_count + assert ( + db.session.execute(db.select(db.func.count()).select_from(User)).scalar() + == initial_user_count + ) mock_send_email.assert_not_called() return -@patch.object(Mail, 'send') -def test_register_duplicate_email(mock_send_email: Mock, client: FlaskClient, new_user_data: dict): - - initial_user_count = db.session.execute(db.select(db.func.count()).select_from(User)).scalar() +@patch.object(Mail, "send") +def test_register_duplicate_email( + mock_send_email: Mock, client: FlaskClient, new_user_data: dict +): + initial_user_count = db.session.execute( + db.select(db.func.count()).select_from(User) + ).scalar() response = client.post("/register", data=new_user_data) data = response.get_json() @@ -106,17 +131,24 @@ def test_register_duplicate_email(mock_send_email: Mock, client: FlaskClient, ne assert response.status_code == 200 assert "errors" in data assert "email" in data["errors"] - assert db.session.execute(db.select(db.func.count()).select_from(User)).scalar() == initial_user_count + assert ( + db.session.execute(db.select(db.func.count()).select_from(User)).scalar() + == initial_user_count + ) mock_send_email.assert_not_called() return -@patch.object(Mail, 'send') -def test_register_weak_password(mock_send_email: Mock, client: FlaskClient, new_user_data: dict): - + +@patch.object(Mail, "send") +def test_register_weak_password( + mock_send_email: Mock, client: FlaskClient, new_user_data: dict +): invalid_user_data = new_user_data.copy() invalid_user_data["password"] = "123" - initial_user_count = db.session.execute(db.select(db.func.count()).select_from(User)).scalar() + initial_user_count = db.session.execute( + db.select(db.func.count()).select_from(User) + ).scalar() response = client.post("/register", data=invalid_user_data) data = response.get_json() @@ -124,7 +156,10 @@ def test_register_weak_password(mock_send_email: Mock, client: FlaskClient, new_ assert response.status_code == 200 assert "errors" in data assert "password" in data["errors"] - assert db.session.execute(db.select(db.func.count()).select_from(User)).scalar() == initial_user_count + assert ( + db.session.execute(db.select(db.func.count()).select_from(User)).scalar() + == initial_user_count + ) mock_send_email.assert_not_called() return From 9b18be7dd0f9b3a2f1126838093cb607f4f58543 Mon Sep 17 00:00:00 2001 From: Neil Shaabi <66903165+neilshaabi@users.noreply.github.com> Date: Sun, 25 Feb 2024 20:32:21 +0000 Subject: [PATCH 6/6] Update test_register.py --- tests/test_register.py | 4 +++- 1 file changed, 3 insertions(+), 1 deletion(-) diff --git a/tests/test_register.py b/tests/test_register.py index 90bd004..4e5894b 100644 --- a/tests/test_register.py +++ b/tests/test_register.py @@ -19,9 +19,11 @@ def new_user_data(fake_user_client: User, fake_user_password: str) -> dict: } -def test_get_register(client: FlaskClient): +@patch.object(Mail, "send") +def test_get_register(mock_send_email: Mock, client: FlaskClient): response = client.get("/register") assert response.status_code == 200 + mock_send_email.assert_not_called() return