From ce5cb61d0f33d1e7cf95bfbe09c876b0f8a18b0a Mon Sep 17 00:00:00 2001 From: Yifan Zhang Date: Thu, 12 Jan 2023 12:51:51 +0300 Subject: [PATCH 01/17] improve Application add owner change to only manager can create Application --- apihub/subscription/router.py | 13 +++++++++++-- apihub/subscription/schemas.py | 4 ++++ tests/test_subscription.py | 27 +++++++++++++-------------- 3 files changed, 28 insertions(+), 16 deletions(-) diff --git a/apihub/subscription/router.py b/apihub/subscription/router.py index 68d153a..cab7903 100644 --- a/apihub/subscription/router.py +++ b/apihub/subscription/router.py @@ -16,6 +16,7 @@ SubscriptionCreate, SubscriptionIn, ApplicationCreate, + ApplicationCreateWithOwner, ) from .queries import ( SubscriptionQuery, @@ -40,13 +41,21 @@ class SubscriptionSettings(BaseSettings): def create_application( application: ApplicationCreate, session: Session = Depends(create_session), - username: str = Depends(require_admin), + user: UserBase = Depends(require_token), ): """ Create an application. """ + applicationCreateWithOwner = ApplicationCreateWithOwner.copy( + application, + update={"owner": user.username} + ) + + if not user.is_manager: + raise HTTPException(401, "Only developers can create applications.") + try: - return ApplicationQuery(session).create_application(application) + return ApplicationQuery(session).create_application(applicationCreateWithOwner) except ApplicationException as e: raise HTTPException(status_code=400, detail=str(e)) diff --git a/apihub/subscription/schemas.py b/apihub/subscription/schemas.py index 8e662b8..d0573a8 100644 --- a/apihub/subscription/schemas.py +++ b/apihub/subscription/schemas.py @@ -34,6 +34,10 @@ class ApplicationCreate(ApplicationBase): pricing: List[SubscriptionPricingBase] +class ApplicationCreateWithOwner(ApplicationCreate): + owner: str + + class SubscriptionBase(BaseModel): username: str application: str diff --git a/tests/test_subscription.py b/tests/test_subscription.py index 723ac92..91ceb22 100644 --- a/tests/test_subscription.py +++ b/tests/test_subscription.py @@ -86,6 +86,18 @@ class Meta: notes = None +def _require_admin_token(): + return UserBase(username="tester", role=UserType.ADMIN) + + +def _require_user_token(): + return UserBase(username="tester", role=UserType.USER) + + +def _require_manager_token(): + return UserBase(username="tester", role=UserType.MANAGER) + + @pytest.fixture(scope="function") def client(db_session): def _create_session(): @@ -100,12 +112,6 @@ def _require_admin(): def _require_user(): return "user" - def _require_admin_token(): - return UserBase(username="tester", role=UserType.ADMIN) - - def _require_user_token(): - return UserBase(username="tester", role=UserType.USER) - app = FastAPI() app.include_router(router) @@ -145,16 +151,9 @@ def api_function_2( yield TestClient(app) -def _require_admin_token(): - return UserBase(username="tester", role=UserType.ADMIN) - - -def _require_user_token(): - return UserBase(username="tester", role=UserType.USER) - - class TestApplication: def test_create_application(self, client): + client.app.dependency_overrides[require_token] = _require_manager_token new_application = ApplicationCreate( name="app", url="/test", From 1c8346928c81318f4b3666968f6840032e30e7b8 Mon Sep 17 00:00:00 2001 From: Yifan Zhang Date: Thu, 19 Jan 2023 10:22:02 +0300 Subject: [PATCH 02/17] add owner to application --- apihub/security/helpers.py | 1 + apihub/subscription/models.py | 3 +++ apihub/subscription/queries.py | 4 +++- tests/test_subscription.py | 3 +++ 4 files changed, 10 insertions(+), 1 deletion(-) diff --git a/apihub/security/helpers.py b/apihub/security/helpers.py index 1907a00..8ef1863 100644 --- a/apihub/security/helpers.py +++ b/apihub/security/helpers.py @@ -20,5 +20,6 @@ def hash_password(password, salt=None): password.encode("utf-8"), salt_, 100000, + dklen=64, ).hex() return salt, hashed_password diff --git a/apihub/subscription/models.py b/apihub/subscription/models.py index 23b3619..95d62c2 100644 --- a/apihub/subscription/models.py +++ b/apihub/subscription/models.py @@ -28,6 +28,9 @@ class Application(Base): url = Column(String) description = Column(String) + created_at = Column(DateTime, default=datetime.now()) + owner = Column(String, ForeignKey("users.username")) + subscriptions = relationship("Subscription", backref="app") subscriptions_pricing = relationship("SubscriptionPricing", backref="app") diff --git a/apihub/subscription/queries.py b/apihub/subscription/queries.py index 2f3656d..cbfdd4d 100644 --- a/apihub/subscription/queries.py +++ b/apihub/subscription/queries.py @@ -13,6 +13,7 @@ SubscriptionCreate, SubscriptionDetails, ApplicationCreate, + ApplicationCreateWithOwner, SubscriptionPricingBase, ) from .helpers import get_and_reset_balance_in_cache @@ -38,7 +39,7 @@ def get_query(self) -> Query: """ return self.session.query(Application) - def create_application(self, application: ApplicationCreate): + def create_application(self, application: ApplicationCreateWithOwner): """ Create an application. :param application: Application details. @@ -58,6 +59,7 @@ def create_application(self, application: ApplicationCreate): name=application.name, url=application.url, description=application.description, + owner=application.owner, subscriptions_pricing=pricing_list, ) try: diff --git a/tests/test_subscription.py b/tests/test_subscription.py index 91ceb22..2f2924e 100644 --- a/tests/test_subscription.py +++ b/tests/test_subscription.py @@ -43,6 +43,9 @@ class Meta: url = factory.Sequence(lambda n: f"app/{n}") description = "description" + created_at = factory.LazyFunction(datetime.now) + owner = "tester" + class SubscriptionPricingFactory(factory.alchemy.SQLAlchemyModelFactory): class Meta: From 90c458c0f0e3f0273cc87905c81d25d3f4ea2fd3 Mon Sep 17 00:00:00 2001 From: Yifan Zhang Date: Thu, 19 Jan 2023 10:32:14 +0300 Subject: [PATCH 03/17] Add alemic to manage data migration --- ...dd_owner_and_created_at_to_application_.py | 26 +++++ alembic.ini | 105 ++++++++++++++++++ alembic/README | 1 + alembic/env.py | 81 ++++++++++++++ alembic/script.py.mako | 24 ++++ ...a4422d71_add_owner_to_application_table.py | 27 +++++ poetry.lock | 44 +++++++- pyproject.toml | 5 +- 8 files changed, 310 insertions(+), 3 deletions(-) create mode 100644 9be9edae04c5_add_owner_and_created_at_to_application_.py create mode 100644 alembic.ini create mode 100644 alembic/README create mode 100644 alembic/env.py create mode 100644 alembic/script.py.mako create mode 100644 alembic/versions/5012a4422d71_add_owner_to_application_table.py diff --git a/9be9edae04c5_add_owner_and_created_at_to_application_.py b/9be9edae04c5_add_owner_and_created_at_to_application_.py new file mode 100644 index 0000000..c859341 --- /dev/null +++ b/9be9edae04c5_add_owner_and_created_at_to_application_.py @@ -0,0 +1,26 @@ +"""add owner and created_at to application table + +Revision ID: 9be9edae04c5 +Revises: +Create Date: 2023-01-12 15:01:46.535112 + +""" +from alembic import op +import sqlalchemy as sa + + +# revision identifiers, used by Alembic. +revision = '9be9edae04c5' +down_revision = None +branch_labels = None +depends_on = None + + +def upgrade() -> None: + op.add_column('application', sa.Column("created_at", sa.DateTime, nullable=False)) + op.add_column('application', sa.Column("owner", sa.ForeignKey('user.username'), nullable=False)) + + +def downgrade() -> None: + op.drop_column('application', "owner") + op.drop_column('application', "created_at") \ No newline at end of file diff --git a/alembic.ini b/alembic.ini new file mode 100644 index 0000000..86d1800 --- /dev/null +++ b/alembic.ini @@ -0,0 +1,105 @@ +# A generic, single database configuration. + +[alembic] +# path to migration scripts +script_location = alembic + +# template used to generate migration file names; The default value is %%(rev)s_%%(slug)s +# Uncomment the line below if you want the files to be prepended with date and time +# see https://alembic.sqlalchemy.org/en/latest/tutorial.html#editing-the-ini-file +# for all available tokens +# file_template = %%(year)d_%%(month).2d_%%(day).2d_%%(hour).2d%%(minute).2d-%%(rev)s_%%(slug)s + +# sys.path path, will be prepended to sys.path if present. +# defaults to the current working directory. +prepend_sys_path = . + +# timezone to use when rendering the date within the migration file +# as well as the filename. +# If specified, requires the python-dateutil library that can be +# installed by adding `alembic[tz]` to the pip requirements +# string value is passed to dateutil.tz.gettz() +# leave blank for localtime +# timezone = + +# max length of characters to apply to the +# "slug" field +# truncate_slug_length = 40 + +# set to 'true' to run the environment during +# the 'revision' command, regardless of autogenerate +# revision_environment = false + +# set to 'true' to allow .pyc and .pyo files without +# a source .py file to be detected as revisions in the +# versions/ directory +# sourceless = false + +# version location specification; This defaults +# to alembic/versions. When using multiple version +# directories, initial revisions must be specified with --version-path. +# The path separator used here should be the separator specified by "version_path_separator" below. +# version_locations = %(here)s/bar:%(here)s/bat:alembic/versions + +# version path separator; As mentioned above, this is the character used to split +# version_locations. The default within new alembic.ini files is "os", which uses os.pathsep. +# If this key is omitted entirely, it falls back to the legacy behavior of splitting on spaces and/or commas. +# Valid values for version_path_separator are: +# +# version_path_separator = : +# version_path_separator = ; +# version_path_separator = space +version_path_separator = os # Use os.pathsep. Default configuration used for new projects. + +# the output encoding used when revision files +# are written from script.py.mako +# output_encoding = utf-8 + +sqlalchemy.url = + + +[post_write_hooks] +# post_write_hooks defines scripts or Python functions that are run +# on newly generated revision scripts. See the documentation for further +# detail and examples + +# format using "black" - use the console_scripts runner, against the "black" entrypoint +# hooks = black +# black.type = console_scripts +# black.entrypoint = black +# black.options = -l 79 REVISION_SCRIPT_FILENAME + +# Logging configuration +[loggers] +keys = root,sqlalchemy,alembic + +[handlers] +keys = console + +[formatters] +keys = generic + +[logger_root] +level = WARN +handlers = console +qualname = + +[logger_sqlalchemy] +level = WARN +handlers = +qualname = sqlalchemy.engine + +[logger_alembic] +level = INFO +handlers = +qualname = alembic + +[handler_console] +class = StreamHandler +args = (sys.stderr,) +level = NOTSET +formatter = generic + +[formatter_generic] +format = %(levelname)-5.5s [%(name)s] %(message)s +datefmt = %H:%M:%S diff --git a/alembic/README b/alembic/README new file mode 100644 index 0000000..98e4f9c --- /dev/null +++ b/alembic/README @@ -0,0 +1 @@ +Generic single-database configuration. \ No newline at end of file diff --git a/alembic/env.py b/alembic/env.py new file mode 100644 index 0000000..c9e8f11 --- /dev/null +++ b/alembic/env.py @@ -0,0 +1,81 @@ +import os +from logging.config import fileConfig + +from sqlalchemy import engine_from_config +from sqlalchemy import pool + +from alembic import context + +# this is the Alembic Config object, which provides +# access to the values within the .ini file in use. +config = context.config + +# Interpret the config file for Python logging. +# This line sets up loggers basically. +if config.config_file_name is not None: + fileConfig(config.config_file_name) + +config.set_main_option('sqlalchemy.url', os.environ.get('DB_URI')) + +# add your model's MetaData object here +# for 'autogenerate' support +import apihub.server +from apihub.common.db_session import Base +target_metadata = Base.metadata + +# other values from the config, defined by the needs of env.py, +# can be acquired: +# my_important_option = config.get_main_option("my_important_option") +# ... etc. + + +def run_migrations_offline() -> None: + """Run migrations in 'offline' mode. + + This configures the context with just a URL + and not an Engine, though an Engine is acceptable + here as well. By skipping the Engine creation + we don't even need a DBAPI to be available. + + Calls to context.execute() here emit the given string to the + script output. + + """ + url = config.get_main_option("sqlalchemy.url") + context.configure( + url=url, + target_metadata=target_metadata, + literal_binds=True, + dialect_opts={"paramstyle": "named"}, + ) + + with context.begin_transaction(): + context.run_migrations() + + +def run_migrations_online() -> None: + """Run migrations in 'online' mode. + + In this scenario we need to create an Engine + and associate a connection with the context. + + """ + connectable = engine_from_config( + config.get_section(config.config_ini_section), + prefix="sqlalchemy.", + poolclass=pool.NullPool, + ) + + with connectable.connect() as connection: + context.configure( + connection=connection, target_metadata=target_metadata + ) + + with context.begin_transaction(): + context.run_migrations() + + +if context.is_offline_mode(): + run_migrations_offline() +else: + run_migrations_online() diff --git a/alembic/script.py.mako b/alembic/script.py.mako new file mode 100644 index 0000000..55df286 --- /dev/null +++ b/alembic/script.py.mako @@ -0,0 +1,24 @@ +"""${message} + +Revision ID: ${up_revision} +Revises: ${down_revision | comma,n} +Create Date: ${create_date} + +""" +from alembic import op +import sqlalchemy as sa +${imports if imports else ""} + +# revision identifiers, used by Alembic. +revision = ${repr(up_revision)} +down_revision = ${repr(down_revision)} +branch_labels = ${repr(branch_labels)} +depends_on = ${repr(depends_on)} + + +def upgrade() -> None: + ${upgrades if upgrades else "pass"} + + +def downgrade() -> None: + ${downgrades if downgrades else "pass"} diff --git a/alembic/versions/5012a4422d71_add_owner_to_application_table.py b/alembic/versions/5012a4422d71_add_owner_to_application_table.py new file mode 100644 index 0000000..77c6fc5 --- /dev/null +++ b/alembic/versions/5012a4422d71_add_owner_to_application_table.py @@ -0,0 +1,27 @@ +"""Add owner to application table + +Revision ID: 5012a4422d71 +Revises: +Create Date: 2023-01-15 10:55:32.594417 + +""" +from alembic import op +import sqlalchemy as sa + + +# revision identifiers, used by Alembic. +revision = '5012a4422d71' +down_revision = None +branch_labels = None +depends_on = None + + +def upgrade() -> None: + op.add_column('application', sa.Column('created_at', sa.DateTime(), nullable=True)) + op.add_column('application', sa.Column('owner', sa.String(), nullable=True)) + op.create_foreign_key(None, 'application', 'users', ['owner'], ['username']) + +def downgrade() -> None: + op.drop_constraint(None, 'application', type_='foreignkey') + op.drop_column('application', 'owner') + op.drop_column('application', 'created_at') diff --git a/poetry.lock b/poetry.lock index 3f08a86..62c8659 100644 --- a/poetry.lock +++ b/poetry.lock @@ -1,5 +1,26 @@ # This file is automatically @generated by Poetry and should not be changed by hand. +[[package]] +name = "alembic" +version = "1.9.1" +description = "A database migration tool for SQLAlchemy." +category = "dev" +optional = false +python-versions = ">=3.7" +files = [ + {file = "alembic-1.9.1-py3-none-any.whl", hash = "sha256:a9781ed0979a20341c2cbb56bd22bd8db4fc1913f955e705444bd3a97c59fa32"}, + {file = "alembic-1.9.1.tar.gz", hash = "sha256:f9f76e41061f5ebe27d4fe92600df9dd612521a7683f904dab328ba02cffa5a2"}, +] + +[package.dependencies] +importlib-metadata = {version = "*", markers = "python_version < \"3.9\""} +importlib-resources = {version = "*", markers = "python_version < \"3.9\""} +Mako = "*" +SQLAlchemy = ">=1.3.0" + +[package.extras] +tz = ["python-dateutil"] + [[package]] name = "async-timeout" version = "4.0.2" @@ -967,6 +988,27 @@ roundrobin = ">=0.0.2" typing-extensions = ">=3.7.4.3" Werkzeug = ">=2.0.0" +[[package]] +name = "mako" +version = "1.2.4" +description = "A super-fast templating language that borrows the best ideas from the existing templating languages." +category = "dev" +optional = false +python-versions = ">=3.7" +files = [ + {file = "Mako-1.2.4-py3-none-any.whl", hash = "sha256:c97c79c018b9165ac9922ae4f32da095ffd3c4e6872b45eded42926deea46818"}, + {file = "Mako-1.2.4.tar.gz", hash = "sha256:d60a3903dc3bb01a18ad6a89cdbe2e4eadc69c0bc8ef1e3773ba53d44c3f7a34"}, +] + +[package.dependencies] +importlib-metadata = {version = "*", markers = "python_version < \"3.8\""} +MarkupSafe = ">=0.9.2" + +[package.extras] +babel = ["Babel"] +lingua = ["lingua"] +testing = ["pytest"] + [[package]] name = "markupsafe" version = "2.1.1" @@ -2420,4 +2462,4 @@ cffi = ["cffi (>=1.11)"] [metadata] lock-version = "2.0" python-versions = "^3.7" -content-hash = "eec650345fe015d0cafb0ec7b621bea41f0871ab93f6d742fb8f8d76b2ac1593" +content-hash = "f217b8aa99c8e57d4f8a77613ff54337e5f60037383521536d00892da025368e" diff --git a/pyproject.toml b/pyproject.toml index 0504ca5..446a60b 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -15,6 +15,7 @@ apihub_admin = "apihub.admin:create_all_statements" [tool.poetry.group.dev.dependencies] openapi-spec-validator = "^0.5.1" +alembic = "^1.9.1" [tool.black] line-length = 88 @@ -42,7 +43,7 @@ select = "B,C,E,F,W,T4,B9" [tool.pytest.ini_options] minversion = "6.0" python_files = "*.py" -addopts = "--color=yes -p no:warnings --ignore=templates --ignore=static --doctest-glob *.rst --ignore=apihub" +addopts = "--color=yes -p no:warnings --ignore=templates --ignore=static --doctest-glob *.rst --ignore=apihub --ignore=alembic" doctest_optionflags= "NORMALIZE_WHITESPACE ELLIPSIS" [tool.poetry.dependencies] @@ -74,4 +75,4 @@ locust = "^2.13.0" [build-system] requires = ["poetry-core>=1.0.0"] -build-backend = "poetry.core.masonry.api" +build-backend = "poetry.core.masonry.api" \ No newline at end of file From d3fb298b4da32e394290413a439d453d1c1d6ee1 Mon Sep 17 00:00:00 2001 From: Yifan Zhang Date: Mon, 23 Jan 2023 11:03:36 +0300 Subject: [PATCH 04/17] refact: manager -> publisher --- apihub/admin.py | 50 +++++++++++++++++++++++++++++++++++ apihub/security/depends.py | 12 ++++----- apihub/security/router.py | 12 ++++----- apihub/security/schemas.py | 6 ++--- apihub/subscription/router.py | 9 +++---- tests/test_security.py | 16 +++++------ tests/test_subscription.py | 13 ++++++--- 7 files changed, 85 insertions(+), 33 deletions(-) create mode 100644 apihub/admin.py diff --git a/apihub/admin.py b/apihub/admin.py new file mode 100644 index 0000000..3ccd55e --- /dev/null +++ b/apihub/admin.py @@ -0,0 +1,50 @@ +import sys + +from pydantic import BaseSettings + +from apihub.common.db_session import db_context, Base, DB_ENGINE +from apihub.common.redis_session import redis_conn +from apihub.security.schemas import UserCreate, UserType +from apihub.security.queries import UserQuery +from apihub.subscription.queries import SubscriptionQuery + + +class SuperUser(BaseSettings): + username: str + password: str + email: str + + def as_usercreate(self): + return UserCreate( + username=self.username, + password=self.password, + email=self.email, + role=UserType.ADMIN, + ) + + +def init(): + Base.metadata.bind = DB_ENGINE + Base.metadata.create_all() + + with db_context() as session: + user = SuperUser().as_usercreate() + UserQuery(session).create_user(user) + sys.stderr.write(f"Admin {user.username} is created!") + + +def create_all_statements(): + from sqlalchemy import create_mock_engine + + def metadata_dump(sql, *multiparams, **params): + print(sql.compile(dialect=engine.dialect)) + + engine = create_mock_engine("postgresql://", metadata_dump) + Base.metadata.bind = engine + Base.metadata.create_all(engine, checkfirst=False) + + +def deinit(): + Base.metadata.bind = DB_ENGINE + Base.metadata.drop_all() + sys.stderr.write("deinit is done!") diff --git a/apihub/security/depends.py b/apihub/security/depends.py index 01b62de..fd59b38 100644 --- a/apihub/security/depends.py +++ b/apihub/security/depends.py @@ -52,8 +52,8 @@ def __init__(self, role: Optional[str] = None, roles: List[str] = list()): def __call__(self, Authorize: AuthJWT = Depends()): Authorize.jwt_required() - roles = Authorize.get_raw_jwt().get("roles", {}) - if any(role in roles for role in self.roles): + role = Authorize.get_raw_jwt().get("role", "") + if role in self.roles: username = Authorize.get_jwt_subject() return username @@ -65,16 +65,16 @@ def __call__(self, Authorize: AuthJWT = Depends()): def require_token(Authorize: AuthJWT = Depends()) -> UserBase: Authorize.jwt_required() - roles = Authorize.get_raw_jwt()["roles"] + role = Authorize.get_raw_jwt()["role"] username = Authorize.get_jwt_subject() return UserBase( username=username, - role=roles[0], + role=role, ) require_admin = UserOfRole(role="admin") -require_manager = UserOfRole(role="manager") +require_publisher = UserOfRole(role="publisher") require_user = UserOfRole(role="user") require_app = UserOfRole(role="app") -require_manager_or_admin = UserOfRole(roles=["admin", "manager"]) +require_publisher_or_admin = UserOfRole(roles=["admin", "publisher"]) diff --git a/apihub/security/router.py b/apihub/security/router.py index e52038a..ea1e634 100644 --- a/apihub/security/router.py +++ b/apihub/security/router.py @@ -29,7 +29,7 @@ def get_config(): class AuthenticateResponse(BaseModel): username: str - roles: List[str] + role: str access_token: str expires_time: int @@ -49,8 +49,6 @@ async def _authenticate( except UserException: raise HTTPException(HTTP_403_FORBIDDEN, "User not found or wrong password") - roles = [user.role] - # make sure the max expires_days won't exceed setting if expires_days > SecuritySettings().security_token_expires_time: expires_days = SecuritySettings().security_token_expires_time @@ -59,12 +57,12 @@ async def _authenticate( expires_time = datetime.timedelta(days=expires_days) access_token = Authorize.create_access_token( subject=user.username, - user_claims={"roles": roles}, + user_claims={"role": user.role}, expires_time=expires_time, ) return AuthenticateResponse( username=user.username, - roles=roles, + role=user.role, expires_time=expires_time.seconds, access_token=access_token, ) @@ -76,7 +74,7 @@ async def get_user( current_user: UserBase = Depends(require_token), session=Depends(create_session), ): - if current_user.is_admin or current_user.is_manager: + if current_user.is_admin or current_user.is_publisher: if username == "me": username = current_user.username elif username == "me": @@ -158,7 +156,7 @@ async def change_password_admin( @router.post("/register") async def register_user( new_user: UserRegister, - current_username: str = Depends(require_app), + current_username: str = Depends(require_app), # FIXME session=Depends(create_session), ): query = UserQuery(session) diff --git a/apihub/security/schemas.py b/apihub/security/schemas.py index d6ad0a9..b204f38 100644 --- a/apihub/security/schemas.py +++ b/apihub/security/schemas.py @@ -7,7 +7,7 @@ class UserType(str, Enum): USER = "user" - MANAGER = "manager" + PUBLISHER = "publisher" APP = "app" ADMIN = "admin" @@ -21,8 +21,8 @@ def is_admin(self) -> bool: return self.role == UserType.ADMIN @property - def is_manager(self) -> bool: - return self.role == UserType.MANAGER + def is_publisher(self) -> bool: + return self.role == UserType.PUBLISHER @property def is_user(self) -> bool: diff --git a/apihub/subscription/router.py b/apihub/subscription/router.py index cab7903..fc4f225 100644 --- a/apihub/subscription/router.py +++ b/apihub/subscription/router.py @@ -9,7 +9,7 @@ from ..security.schemas import ( UserBase, ) -from ..security.depends import require_admin, require_token +from ..security.depends import require_admin, require_publisher, require_token from ..security.queries import UserQuery, UserException from .schemas import ( @@ -41,19 +41,16 @@ class SubscriptionSettings(BaseSettings): def create_application( application: ApplicationCreate, session: Session = Depends(create_session), - user: UserBase = Depends(require_token), + username: str = Depends(require_publisher), ): """ Create an application. """ applicationCreateWithOwner = ApplicationCreateWithOwner.copy( application, - update={"owner": user.username} + update={"owner": username} ) - if not user.is_manager: - raise HTTPException(401, "Only developers can create applications.") - try: return ApplicationQuery(session).create_application(applicationCreateWithOwner) except ApplicationException as e: diff --git a/tests/test_security.py b/tests/test_security.py index 4b640bd..921e8a1 100644 --- a/tests/test_security.py +++ b/tests/test_security.py @@ -76,8 +76,8 @@ def protected(Authorize: AuthJWT = Depends()): Authorize.jwt_required() user = Authorize.get_jwt_subject() - roles = Authorize.get_raw_jwt()["roles"] - return {"user": user, "roles": roles} + role = Authorize.get_raw_jwt()["role"] + return {"user": user, "role": role} @app.get("/admin") def admin(username=Depends(require_admin)): @@ -90,7 +90,7 @@ def admin(username=Depends(require_admin)): UserFactory(username="tester", role=UserType.USER) UserFactory(username="admin", role=UserType.ADMIN) - UserFactory(username="manager", role=UserType.MANAGER) + UserFactory(username="publisher", role=UserType.PUBLISHER) UserFactory(username="user", role=UserType.USER) UserFactory(username="app", role=UserType.APP) @@ -149,24 +149,24 @@ def test_token(self, client): def test_require_admin_when_admin(self, client): response = client.get( "/_authenticate", - headers=self._make_auth_header("manager", "password"), + headers=self._make_auth_header("admin", "password"), ) assert response.status_code == 200 auth_response = AuthenticateResponse.parse_obj(response.json()) token = auth_response.access_token response = client.get("/admin", headers={"Authorization": f"Bearer {token}"}) - assert response.status_code == 403 + assert response.status_code == 200 def test_require_admin_when_manager(self, client): response = client.get( "/_authenticate", - headers=self._make_auth_header("admin", "password"), + headers=self._make_auth_header("publisher", "password"), ) assert response.status_code == 200 auth_response = AuthenticateResponse.parse_obj(response.json()) token = auth_response.access_token response = client.get("/admin", headers={"Authorization": f"Bearer {token}"}) - assert response.status_code == 200 + assert response.status_code == 403 def test_create_and_get_user(self, client): response = client.get( @@ -218,7 +218,7 @@ def test_get_users(self, client): response = client.get( "/user", headers={"Authorization": f"Bearer {token}"}, - json={"usernames": "admin,manager,user"}, + json={"usernames": "admin,publisher,user"}, ) assert response.status_code == 200 assert len(response.json()) == 3 diff --git a/tests/test_subscription.py b/tests/test_subscription.py index 2f2924e..8fe63ba 100644 --- a/tests/test_subscription.py +++ b/tests/test_subscription.py @@ -10,7 +10,7 @@ from apihub.common.db_session import create_session from apihub.security.models import User from apihub.security.schemas import UserBase, UserType -from apihub.security.depends import require_user, require_admin, require_token +from apihub.security.depends import require_user, require_admin, require_token, require_publisher from apihub.subscription.depends import ( require_subscription_balance, ) @@ -97,7 +97,7 @@ def _require_user_token(): return UserBase(username="tester", role=UserType.USER) -def _require_manager_token(): +def _require_publisher_token(): return UserBase(username="tester", role=UserType.MANAGER) @@ -115,12 +115,16 @@ def _require_admin(): def _require_user(): return "user" + def _require_publisher(): + return "publisher" + app = FastAPI() app.include_router(router) app.dependency_overrides[create_session] = _create_session app.dependency_overrides[require_admin] = _require_admin app.dependency_overrides[require_user] = _require_user + app.dependency_overrides[require_publisher] = _require_publisher app.dependency_overrides[require_token] = _require_user_token @app.get("/api_balance/{application}") @@ -133,6 +137,10 @@ def api_function_2( UserFactory._meta.sqlalchemy_session_persistence = "commit" UserFactory(username="tester", role=UserType.USER) + UserFactory._meta.sqlalchemy_session = db_session + UserFactory._meta.sqlalchemy_session_persistence = "commit" + UserFactory(username="publisher", role=UserType.PUBLISHER) + ApplicationFactory._meta.sqlalchemy_session = db_session ApplicationFactory._meta.sqlalchemy_session_persistence = "commit" application = ApplicationFactory(name="test", url="/test") @@ -156,7 +164,6 @@ def api_function_2( class TestApplication: def test_create_application(self, client): - client.app.dependency_overrides[require_token] = _require_manager_token new_application = ApplicationCreate( name="app", url="/test", From 873bb6fcb9abf60e359b211aef107e7d5fc575f1 Mon Sep 17 00:00:00 2001 From: Yifan Zhang Date: Mon, 23 Jan 2023 14:26:42 +0300 Subject: [PATCH 05/17] alembic: read database url from env --- alembic/env.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/alembic/env.py b/alembic/env.py index c9e8f11..f96cf10 100644 --- a/alembic/env.py +++ b/alembic/env.py @@ -15,7 +15,7 @@ if config.config_file_name is not None: fileConfig(config.config_file_name) -config.set_main_option('sqlalchemy.url', os.environ.get('DB_URI')) +config.set_main_option('sqlalchemy.url', os.environ.get('DB_URI').replace('%', '%%')) # add your model's MetaData object here # for 'autogenerate' support From 2979a51a343eb5900ea247b939d292ae38dc8ced Mon Sep 17 00:00:00 2001 From: Yifan Zhang Date: Tue, 24 Jan 2023 16:11:59 +0300 Subject: [PATCH 06/17] refact: everything --- apihub/admin.py | 33 ++++- apihub/cli.py | 226 -------------------------------- apihub/client.py | 154 ---------------------- apihub/result.py | 18 ++- apihub/security/depends.py | 32 +++-- apihub/security/helpers.py | 12 ++ apihub/security/models.py | 31 ++++- apihub/security/queries.py | 72 ++++++----- apihub/security/router.py | 74 +++++------ apihub/security/schemas.py | 14 +- apihub/server.py | 23 ++-- apihub/subscription/depends.py | 56 ++++---- apihub/subscription/helpers.py | 17 +-- apihub/subscription/models.py | 64 +++++---- apihub/subscription/queries.py | 200 ++++++++++++++++++---------- apihub/subscription/router.py | 92 +++++++------ apihub/subscription/schemas.py | 25 ++-- tests/test_activity.py | 106 --------------- tests/test_result.py | 38 +++--- tests/test_security.py | 99 +++++++------- tests/test_server.py | 9 +- tests/test_subscription.py | 229 +++++++++++++-------------------- 22 files changed, 618 insertions(+), 1006 deletions(-) delete mode 100644 apihub/cli.py delete mode 100644 apihub/client.py delete mode 100644 tests/test_activity.py diff --git a/apihub/admin.py b/apihub/admin.py index 3ccd55e..7bf86b4 100644 --- a/apihub/admin.py +++ b/apihub/admin.py @@ -6,17 +6,19 @@ from apihub.common.redis_session import redis_conn from apihub.security.schemas import UserCreate, UserType from apihub.security.queries import UserQuery +from apihub.security.models import User, Profile from apihub.subscription.queries import SubscriptionQuery +from apihub.subscription.models import Application, Subscription, Pricing class SuperUser(BaseSettings): - username: str + name: str password: str email: str def as_usercreate(self): return UserCreate( - username=self.username, + name=self.name, password=self.password, email=self.email, role=UserType.ADMIN, @@ -30,7 +32,7 @@ def init(): with db_context() as session: user = SuperUser().as_usercreate() UserQuery(session).create_user(user) - sys.stderr.write(f"Admin {user.username} is created!") + print(f"Admin {user.name} is created!", file=sys.stderr) def create_all_statements(): @@ -47,4 +49,27 @@ def metadata_dump(sql, *multiparams, **params): def deinit(): Base.metadata.bind = DB_ENGINE Base.metadata.drop_all() - sys.stderr.write("deinit is done!") + print("deinit is done!", file=sys.stderr) + + +def load_data(filename): + import yaml + data = yaml.load(open(filename, 'r', encoding='utf-8'),) + for name, items in data.items(): + if name == 'user': + with db_context() as session: + query = UserQuery(session) + for item in items: + user_id = query.create_user( + UserCreate( + name=item.name, + email=item.email, + password=item.password, + role=item.role, + ) + ) + if user_id: + elif name == 'application': + for item in items: + else: + print(f"model {name} not supported", file=sys.stderr) diff --git a/apihub/cli.py b/apihub/cli.py deleted file mode 100644 index 5471d24..0000000 --- a/apihub/cli.py +++ /dev/null @@ -1,226 +0,0 @@ -import sys -import json -import time -import asyncio -import fileinput -from datetime import datetime, timedelta - -try: - import typer -except ImportError: - sys.stderr.write('Please install "apihub[cli] to install cli option"') - sys.exit(1) - -from dotenv import load_dotenv - -load_dotenv() - -from apihub.client import Client -from .subscription.schemas import SubscriptionIn - - -cli = typer.Typer() - - -def try_load_state(username): - client = Client.load_state(filename=f"{username}.apihub") - if client.token is None: - typer.echo("You need to login first") - sys.exit(1) - return client - - -@cli.command() -def login(username: str, password: str, endpoint: str = "http://localhost:5000"): - client = Client({"endpoint": endpoint}) - client.authenticate(username=username, password=password) - client.save_state(filename=f"{username}.apihub") - - -@cli.command() -def refresh_token( - application: str, - admin: str = "", - manager: str = "", - username: str = "", - expires: int = 1, -): - if admin: - admin_client = Client.load_state(filename=f"{admin}.apihub") - elif manager: - admin_client = Client.load_state(filename=f"{manager}.apihub") - else: - admin_client = None - client = Client.load_state(filename=f"{username}.apihub") - - if admin_client: - if admin_client.token is None: - cli.echo("You need to login first") - sys.exit(1) - token = admin_client.refresh_application_token(application, username, expires) - client.applications[application] = token - else: - if client.token is None: - cli.echo("You need to login first") - sys.exit(1) - - client.refresh_application_token(application, username, expires) - client.save_state(filename=f"{username}.apihub") - - -@cli.command() -def list_users(role: str, admin: str = ""): - client = Client.load_state(filename=f"{admin}.apihub") - if client.token is None: - cli.echo("You need to login first") - sys.exit(1) - - users = client.get_users_by_role(role) - for user in users: - print(user) - - -@cli.command() -def create_user( - username: str, password: str, email: str, admin: str = "", role: str = "user" -): - client = Client.load_state(filename=f"{admin}.apihub") - if client.token is None: - cli.echo("You need to login first") - sys.exit(1) - - client.create_user( - { - "username": username, - "password": password, - "role": role, - "email": email, - } - ) - - -@cli.command() -def create_subscription( - username: str, - application: str, - admin: str = "", - days: int = 0, - recurring: bool = False, -): - client = Client.load_state(filename=f"{admin}.apihub") - if client.token is None: - cli.echo("You need to login first") - sys.exit(1) - - expires_at = datetime.now() + timedelta(days=days) if days else None - client.create_subscription( - SubscriptionIn( - username=username, - application=application, - starts_at=datetime.now(), - expires_at=expires_at, - recurring=recurring, - ) - ) - - -@cli.command() -def post_request(application: str, username: str): - client = try_load_state(username) - - data = json.loads(sys.stdin.read()) - response = client.async_request(application, data) - print(response) - return - - MARKER = "# Please write body in json above" - message = cli.edit("\n\n" + MARKER) - if message is not None: - data = json.loads(message.split(MARKER, 1)[0]) - response = client.async_request(application, data) - print(response) - else: - cli.echo("input cannot be empty") - sys.exit(1) - - -@cli.command() -def fetch_result(application: str, key: str, username: str = ""): - client = try_load_state(username) - - response = client.async_result(application, key) - print(response) - - -@cli.command() -def batch_request( - username: str, application: str, filename: str = "", parallel: int = 10 -): - jobs = [] - client = try_load_state(username) - files = [filename] if filename else [] - for i, line in enumerate(fileinput.input(files=files)): - data = json.loads(line) - response = client.async_request(application, data) - if "success" not in response or not response["success"]: - print(f"Failed on input line {i}, request failed", file=sys.stderr) - print(response) - sys.exit(1) - - key = response["key"] - if len(jobs) < parallel: - jobs.append((i, data, key)) - else: - unfinished_jobs = [] - for i, data, key in jobs: - response = client.async_result(application, key) - if "success" in response and response["success"]: - data.update(response["result"]) - print(i, json.dumps(response["result"])) - else: - print(i, response) - unfinished_jobs.append((i, data, key)) - - jobs = unfinished_jobs - - time.sleep(1) - - while len(jobs) > 0: - unfinished_jobs = [] - for i, data, key in jobs: - response = client.async_result(application, key) - if "success" in response and response["success"]: - data.update(response["result"]) - print(i, json.dumps(response["result"])) - else: - print(i, response) - unfinished_jobs.append((i, data, key)) - - jobs = unfinished_jobs - - -@cli.command() -def batch_sync_request( - username: str, application: str, filename: str = "", parallel: int = 10 -): - queue = asyncio.Queue() - - async def request_routine(client, application, data): - response = await client.sync_request(application, data) - if "success" in response and response["success"]: - data.update(response["result"]) - print(i, json.dumps(response["result"])) - else: - print(i, response) - - client = try_load_state(username) - files = [filename] if filename else [] - for i, line in enumerate(fileinput.input(files=files)): - data = json.loads(line) - - queue.put(request_routine(client, application, data)) - time.sleep(1) - - -if __name__ == "__main__": - cli() diff --git a/apihub/client.py b/apihub/client.py deleted file mode 100644 index 55257ba..0000000 --- a/apihub/client.py +++ /dev/null @@ -1,154 +0,0 @@ -import json -from typing import Optional, Any, Dict - -import requests -from pydantic import BaseSettings - -from apihub.security.router import UserCreate -from .subscription.schemas import SubscriptionIn - - -class ClientSettings(BaseSettings): - endpoint: str = "http://localhost" - token: Optional[str] = None - - -class Client: - def __init__(self, settings: Dict[str, Any]) -> None: - if settings is not None: - self.settings = ClientSettings.parse_obj(settings) - else: - self.settings = ClientSettings() - self.token: Optional[str] = None - self.applications: Dict[str, Any] = {} - - def _make_url(self, path: str) -> str: - return f"{self.settings.endpoint}/{path}" - - def save_state(self, filename="~/.apihubrc") -> None: - json.dump( - { - "settings": self.settings.dict(), - "token": self.token, - "applications": self.applications, - }, - open(filename, "w"), - ) - - @staticmethod - def load_state(filename="~/.apihubrc") -> "Client": - state = json.load(open(filename)) - client = Client(state["settings"]) - client.token = state["token"] - client.applications = state["applications"] - return client - - def authenticate(self, username: str, password: str) -> None: - # TODO exceptions - response = requests.get( - self._make_url("_authenticate"), - auth=(username, password), - ) - if response.status_code == 200: - print(response.json()) - self.token = response.json()["access_token"] - else: - print(response.json()) - - def create_user(self, user): - username = user["username"] - response = requests.post( - self._make_url(f"user/{username}"), - headers={"Authorization": f"Bearer {self.token}"}, - json=UserCreate.parse_obj(user).dict(), - ) - if response.status_code == 200: - return True - - def get_users_by_role(self, role): - response = requests.get( - self._make_url(f"users/{role}"), - headers={"Authorization": f"Bearer {self.token}"}, - ) - if response.status_code == 200: - return response.json() - - def create_subscription(self, subscription: SubscriptionIn): - response = requests.post( - self._make_url("subscription"), - headers={"Authorization": f"Bearer {self.token}"}, - data=subscription.json(), - ) - if response.status_code == 200: - return True - else: - raise Exception(response.json()) - - def refresh_application_token( - self, application: str, username: str, expires_days: int - ) -> None: - # TODO exceptions - params = {} - if username: - params["username"] = username - if expires_days: - params["expires_days"] = expires_days - - response = requests.get( - self._make_url(f"token/{application}"), - headers={"Authorization": f"Bearer {self.token}"}, - params=params, - ) - if response.status_code == 200: - print(response.text) - print(response.json()) - self.applications[application] = response.json()["token"] - return response.json()["token"] - else: - print(response.text) - print(response.json()) - - def _check_token_for(self, application: str) -> bool: - return application in self.applications - - def async_request(self, application: str, data: dict): - response = requests.post( - self._make_url(f"async/{application}"), - headers={ - "Authorization": f"Bearer {self.applications[application]}", - }, - json=data, - ) - if response.status_code == 200: - return response.json() - else: - return response.json() - - def async_result(self, application: str, key: str): - # TODO wait and timeout - response = requests.get( - self._make_url(f"async/{application}"), - params={ - "key": key, - }, - headers={ - "Authorization": f"Bearer {self.token}", - }, - ) - if response.status_code == 200: - return response.json() - else: - return response.json() - - def sync_request(self, application: str, data: dict): - response = requests.post( - self._make_url(f"sync/{application}"), - headers={ - "Authorization": f"Bearer {self.applications[application]}", - }, - json=data, - ) - if response.status_code == 200: - return response.json() - else: - return response.json() diff --git a/apihub/result.py b/apihub/result.py index 2b170b2..11b9eb9 100644 --- a/apihub/result.py +++ b/apihub/result.py @@ -4,8 +4,6 @@ from pipeline import ProcessorSettings, Processor, Command, CommandActions, Definition -from .activity.queries import ActivityQuery -from .activity.schemas import ActivityStatus from .common.db_session import create_session from .utils import Result, RedisSettings, DefinitionManager from . import __worker__, __version__ @@ -61,10 +59,10 @@ def process_command(self, command: Command) -> None: def process(self, message_content, message_id): self.logger.info("Processing MESSAGE") result = Result.parse_obj(message_content) - if result.status == ActivityStatus.PROCESSED: - result.result = { - k: message_content.get(k) for k in self.message.logs[-1].updated - } + # if result.status == ActivityStatus.PROCESSED: + # result.result = { + # k: message_content.get(k) for k in self.message.logs[-1].updated + # } self.api_counter.labels(api=result.api, user=result.user, status=result.status) @@ -74,10 +72,10 @@ def process(self, message_content, message_id): r = result.json() self.redis.set(message_id, r, ex=86400) - if result.status == ActivityStatus.PROCESSED: - ActivityQuery(self.session).update_activity( - message_id, **{"status": ActivityStatus.PROCESSED} - ) + # if result.status == ActivityStatus.PROCESSED: + # ActivityQuery(self.session).update_activity( + # message_id, **{"status": ActivityStatus.PROCESSED} + # ) return None diff --git a/apihub/security/depends.py b/apihub/security/depends.py index fd59b38..feb4e81 100644 --- a/apihub/security/depends.py +++ b/apihub/security/depends.py @@ -4,7 +4,7 @@ from fastapi import HTTPException, Depends, Request from fastapi_jwt_auth import AuthJWT -from .schemas import UserBase +from .schemas import UserBaseWithId HTTP_429_TOO_MANY_REQUESTS = 429 @@ -52,10 +52,18 @@ def __init__(self, role: Optional[str] = None, roles: List[str] = list()): def __call__(self, Authorize: AuthJWT = Depends()): Authorize.jwt_required() - role = Authorize.get_raw_jwt().get("role", "") + claims = Authorize.get_raw_jwt() + role = claims.get("role", "") if role in self.roles: - username = Authorize.get_jwt_subject() - return username + name = claims.get("name", "") + email = Authorize.get_jwt_subject() + user_id = claims.get("id", "") + return UserBaseWithId( + id=user_id, + name=name, + email=email, + role=role, + ) raise HTTPException( HTTP_403_FORBIDDEN, @@ -63,12 +71,17 @@ def __call__(self, Authorize: AuthJWT = Depends()): ) -def require_token(Authorize: AuthJWT = Depends()) -> UserBase: +def require_token(Authorize: AuthJWT = Depends()) -> UserBaseWithId: Authorize.jwt_required() - role = Authorize.get_raw_jwt()["role"] - username = Authorize.get_jwt_subject() - return UserBase( - username=username, + claims = Authorize.get_raw_jwt() + role = claims.get("role", "") + name = claims.get("name", "") + email = Authorize.get_jwt_subject() + user_id = claims.get("id", "") + return UserBaseWithId( + id=user_id, + name=name, + email=email, role=role, ) @@ -78,3 +91,4 @@ def require_token(Authorize: AuthJWT = Depends()) -> UserBase: require_user = UserOfRole(role="user") require_app = UserOfRole(role="app") require_publisher_or_admin = UserOfRole(roles=["admin", "publisher"]) +require_logged_in = UserOfRole(roles=["admin", "publisher", "user", "app"]) diff --git a/apihub/security/helpers.py b/apihub/security/helpers.py index 8ef1863..887e7ed 100644 --- a/apihub/security/helpers.py +++ b/apihub/security/helpers.py @@ -1,6 +1,8 @@ import os +import datetime import hashlib from base64 import b64encode, b64decode +from fastapi_jwt_auth import AuthJWT def hash_password(password, salt=None): @@ -23,3 +25,13 @@ def hash_password(password, salt=None): dklen=64, ).hex() return salt, hashed_password + + +def make_token(user, expires_time): + Authorize = AuthJWT() + access_token = Authorize.create_access_token( + subject=user.email, + user_claims={"role": user.role, "name": user.name, "id": user.id}, + expires_time=expires_time, + ) + return access_token \ No newline at end of file diff --git a/apihub/security/models.py b/apihub/security/models.py index 2a09716..e6ead66 100644 --- a/apihub/security/models.py +++ b/apihub/security/models.py @@ -8,6 +8,7 @@ String, DateTime, Enum, + ForeignKey, ) from .schemas import UserType @@ -22,14 +23,36 @@ class User(Base): __tablename__ = "users" id = Column(Integer, primary_key=True, index=True) - username = Column(String, unique=True, index=True, nullable=False) - email = Column(String, index=True) + email = Column(String, index=True, unique=True, nullable=False) + name = Column(String, nullable=False) salt = Column(String) hashed_password = Column(String) role = Column(Enum(UserType), default=UserType.USER) is_active = Column(Boolean, default=True) created_at = Column(DateTime, default=datetime.now()) - subscriptions = relationship("Subscription", back_populates="user") + + profile = relationship("Profile", uselist=False, back_populates="user") + + def __str__(self): + return f"{self.email} || {self.role} || {self.is_active}" + + + + +class Profile(Base): + """ + This class is used to store user profile data. + """ + __tablename__ = "profiles" + + id = Column(Integer, primary_key=True, index=True) + name = Column(String) + bio = Column(String) + url = Column(String) + avatar = Column(String) + + user_id = Column(Integer, ForeignKey("users.id")) + user = relationship("User", cascade = "all,delete", back_populates="profile") def __str__(self): - return f"{self.username} || {self.role}" + return f"{self.user_id} || {self.name}" \ No newline at end of file diff --git a/apihub/security/queries.py b/apihub/security/queries.py index 1b99dc8..c3156f4 100644 --- a/apihub/security/queries.py +++ b/apihub/security/queries.py @@ -1,4 +1,4 @@ -from typing import List +from typing import List, Optional from sqlalchemy.exc import IntegrityError, DataError from sqlalchemy.orm.exc import NoResultFound @@ -28,26 +28,30 @@ def get_user_by_id(self, user_id: int) -> UserSession: :param user_id: integer id :return: UserSession object. """ - user = self.get_query().filter(User.id == user_id).one() - return UserSession( - id=user.id, - username=user.username, - role=user.role, - salt=user.salt, - hashed_password=user.hashed_password, - ) + try: + user = self.get_query().filter(User.id == user_id).one() + return UserSession( + id=user.id, + name=user.name, + email=user.email, + role=user.role, + salt=user.salt, + hashed_password=user.hashed_password, + ) + except NoResultFound: + raise UserException("User not found.") - def get_user_by_username_and_password( - self, username: str, password: str + def get_user_by_email_and_password( + self, email: str, password: str ) -> UserSession: """ - Get user by username and password. - :param username: str + Get user by email and password. + :param email: str :param password: str :return: UserSession object. """ try: - user = self.get_query().filter(User.username == username).one() + user = self.get_query().filter(User.email == email).one() except NoResultFound: raise UserException @@ -55,47 +59,50 @@ def get_user_by_username_and_password( if user.hashed_password == hashed_password: return UserSession( id=user.id, - username=user.username, + name=user.name, + email=user.email, role=user.role, salt=user.salt, hashed_password=user.hashed_password, ) raise UserException - def get_user_by_username(self, username: str) -> UserSession: + def get_user_by_email(self, email: str) -> UserSession: """ - Get user by username. - :param username: str + Get user by email. + :param email: str :return: UserSession object. """ try: - user = self.get_query().filter(User.username == username).one() + user = self.get_query().filter(User.email == email).one() except NoResultFound: raise UserException return UserSession( id=user.id, - username=user.username, + name=user.name, + email=user.email, role=user.role, salt=user.salt, hashed_password=user.hashed_password, ) - def get_users_by_usernames(self, usernames: List[str]) -> List[UserSession]: + def get_users_by_emails(self, emails: List[str]) -> List[UserSession]: """ - Get users by usernames. - :param usernames: str + Get users by emails. + :param emails: str :return: list of UserSession object. """ try: - users = self.get_query().filter(User.username.in_(usernames)) + users = self.get_query().filter(User.email.in_(emails)) except NoResultFound: raise UserException return [ UserSession( id=user.id, - username=user.username, + name=user.name, + email=user.email, role=user.role, salt=user.salt, hashed_password=user.hashed_password, @@ -116,13 +123,14 @@ def get_users_by_role(self, role) -> List[UserBase]: return [ UserBase( - username=user.username, + name=user.name, + email=user.email, role=user.role, ) for user in users ] - def create_user(self, user: UserCreate) -> bool: + def create_user(self, user: UserCreate) -> Optional[int]: """ Create a new user. :param user: UserCreate object. @@ -136,19 +144,19 @@ def create_user(self, user: UserCreate) -> bool: self.session.rollback() return False - return True + return db_user.id - def change_password(self, username: str, password: str) -> bool: + def change_password(self, email: str, password: str) -> bool: """ Change password for a user. - :param username: str + :param email: str :param password: str :return: boolean. """ try: - user_in_db = self.get_query().filter(User.username == username).one() + user_in_db = self.get_query().filter(User.email == email).one() except NoResultFound: - raise UserException + raise UserException(f"User {email} not found") _, hashed_password = hash_password(password, salt=user_in_db.salt) user_in_db.hashed_password = hashed_password diff --git a/apihub/security/router.py b/apihub/security/router.py index ea1e634..481c77c 100644 --- a/apihub/security/router.py +++ b/apihub/security/router.py @@ -9,6 +9,7 @@ from .schemas import UserCreate, UserBase, UserRegister, UserType from .queries import UserQuery, UserException from .depends import require_token, require_admin, require_app +from .helpers import make_token security = HTTPBasic() @@ -28,7 +29,7 @@ def get_config(): class AuthenticateResponse(BaseModel): - username: str + email: str role: str access_token: str expires_time: int @@ -42,8 +43,8 @@ async def _authenticate( ): query = UserQuery(session) try: - user = query.get_user_by_username_and_password( - username=credentials.username, + user = query.get_user_by_email_and_password( + email=credentials.username, password=credentials.password, ) except UserException: @@ -53,55 +54,44 @@ async def _authenticate( if expires_days > SecuritySettings().security_token_expires_time: expires_days = SecuritySettings().security_token_expires_time - Authorize = AuthJWT() expires_time = datetime.timedelta(days=expires_days) - access_token = Authorize.create_access_token( - subject=user.username, - user_claims={"role": user.role}, - expires_time=expires_time, - ) + + Authorize = AuthJWT() + access_token = make_token(user, expires_time) return AuthenticateResponse( - username=user.username, + email=user.email, role=user.role, expires_time=expires_time.seconds, access_token=access_token, ) -@router.get("/user/{username}") +@router.get("/user") async def get_user( - username: str, - current_user: UserBase = Depends(require_token), + user: UserBase = Depends(require_token), session=Depends(create_session), ): - if current_user.is_admin or current_user.is_publisher: - if username == "me": - username = current_user.username - elif username == "me": - username = current_user.username - else: - raise HTTPException(HTTP_403_FORBIDDEN, "You have no permission") - query = UserQuery(session) - user = query.get_user_by_username(username=username) + user = query.get_user_by_email(email=user.email) return UserBase( - username=user.username, + name=user.name, + email=user.email, role=user.role, ) class GetUserAdminIn(BaseModel): - usernames: str + emails: str @router.get("/user") async def get_user_admin( - usernames: GetUserAdminIn, - current_username: str = Depends(require_admin), + group: GetUserAdminIn, + admin: str = Depends(require_admin), session=Depends(create_session), ): query = UserQuery(session) - users = query.get_users_by_usernames(usernames=usernames.usernames.split(",")) + users = query.get_users_by_emails(emails=group.emails.split(",")) return [UserBase(**user.dict()) for user in users] @@ -112,21 +102,21 @@ class ChangePasswordIn(BaseModel): @router.post("/user/_password") async def change_password( password: ChangePasswordIn, - current_user: UserBase = Depends(require_token), + user: UserBase = Depends(require_token), session=Depends(create_session), ): query = UserQuery(session) - query.change_password(username=current_user.username, password=password.password) + query.change_password(email=user.email, password=password.password) @router.post("/user") async def create_user( - new_user: UserCreate, - current_username: str = Depends(require_admin), + user: UserCreate, + admin: str = Depends(require_admin), session=Depends(create_session), ): query = UserQuery(session) - query.create_user(new_user) + query.create_user(user) # TODO handling results return {} @@ -134,7 +124,7 @@ async def create_user( @router.get("/users/{role}") async def list_users( role: str, - current_username: str = Depends(require_admin), + admin: str = Depends(require_admin), session=Depends(create_session), ): query = UserQuery(session) @@ -142,29 +132,29 @@ async def list_users( return users -@router.post("/user/{username}/_password") +@router.post("/user/{email}/_password") async def change_password_admin( - username: str, + email: str, password: ChangePasswordIn, - current_username: str = Depends(require_admin), + admin: str = Depends(require_admin), session=Depends(create_session), ): query = UserQuery(session) - query.change_password(username=username, password=password.password) + query.change_password(email=email, password=password.password) @router.post("/register") async def register_user( - new_user: UserRegister, - current_username: str = Depends(require_app), # FIXME + user: UserRegister, + app: str = Depends(require_app), # FIXME session=Depends(create_session), ): query = UserQuery(session) query.create_user( UserCreate( - username=new_user.username, - email=new_user.email, - password=new_user.password, + name=user.name, + email=user.email, + password=user.password, role=UserType.USER, ) ) diff --git a/apihub/security/schemas.py b/apihub/security/schemas.py index b204f38..d68d82a 100644 --- a/apihub/security/schemas.py +++ b/apihub/security/schemas.py @@ -13,7 +13,8 @@ class UserType(str, Enum): class UserBase(BaseModel): - username: str + email: str + name: str role: UserType @property @@ -33,14 +34,17 @@ def is_app(self) -> bool: return self.role == UserType.APP +class UserBaseWithId(UserBase): + id: int + + class UserRegister(BaseModel): - username: str + name: str email: str password: str class UserCreate(UserBase): - email: str password: str def make_user(self): @@ -50,7 +54,7 @@ def make_user(self): """ salt, hashed_password = hash_password(self.password) return UserCreateHashed( - username=self.username, + name=self.name, email=self.email, salt=salt, hashed_password=hashed_password, @@ -70,4 +74,4 @@ class UserSession(UserBase): class User(UserSession): - pass + pass \ No newline at end of file diff --git a/apihub/server.py b/apihub/server.py index 6249a53..d2afe0c 100644 --- a/apihub/server.py +++ b/apihub/server.py @@ -18,9 +18,8 @@ from .activity.queries import ActivityQuery from .security.depends import RateLimiter, RateLimits, require_user from .security.router import router as security_router -from .subscription.depends import require_subscription +from .subscription.depends import require_subscription, SubscriptionResponse from .subscription.router import router as subscription_router -from .subscription.schemas import SubscriptionBase from .utils import ( State, make_topic, @@ -126,7 +125,7 @@ async def define_service( return {"define": f"application {application}"} -async def make_request(username: str, application: str, request: Request): +async def make_request(email: str, application: str, request: Request): """Make request to application""" key = make_key() @@ -149,7 +148,7 @@ async def make_request(username: str, application: str, request: Request): # inject user information info = Result( - user=username, + user=email, api=application, status=ActivityStatus.ACCEPTED, ) @@ -164,12 +163,12 @@ async def make_request(username: str, application: str, request: Request): return key -def fetch_result(username: str, application: str, key: str): +def fetch_result(email: str, application: str, key: str): """fetch result""" result = get_redis().get(key) if result is None: operation_counter.labels( - api=application, user=username, operation="result_not_found" + api=application, user=email, operation="result_not_found" ).inc() raise HTTPException( status_code=404, @@ -185,7 +184,7 @@ def fetch_result(username: str, application: str, key: str): ) elif result.status != ActivityStatus.PROCESSED: operation_counter.labels( - api=application, user=username, operation="error" + api=application, user=email, operation="error" ).inc() # FIXME change status code raise HTTPException( @@ -205,18 +204,18 @@ def fetch_result(username: str, application: str, key: str): async def async_service( request: Request, # background_tasks: BackgroundTasks, - subscription: SubscriptionBase = Depends(require_subscription), + subscription: SubscriptionResponse = Depends(require_subscription), ): """generic handler for async api.""" - username = subscription.username + email = subscription.email tier = subscription.tier application = subscription.application - operation_counter.labels(api=application, user=username, operation="received").inc() + operation_counter.labels(api=application, user=email, operation="received").inc() - key = await make_request(username, application, request) + key = await make_request(email, application, request) - operation_counter.labels(api=application, user=username, operation="accepted").inc() + operation_counter.labels(api=application, user=email, operation="accepted").inc() # activity = ActivityCreate( # request=f"/async/{application}", diff --git a/apihub/subscription/depends.py b/apihub/subscription/depends.py index d3966d9..e604037 100644 --- a/apihub/subscription/depends.py +++ b/apihub/subscription/depends.py @@ -1,3 +1,4 @@ +from pydantic import BaseModel from fastapi import HTTPException, Depends from fastapi_jwt_auth import AuthJWT from redis import Redis @@ -5,7 +6,6 @@ from ..common.db_session import create_session from ..common.redis_session import redis_conn -from .schemas import SubscriptionBase from .queries import SubscriptionQuery from .helpers import make_key, BALANCE_KEYS @@ -14,9 +14,18 @@ HTTP_429_QUOTA = 429 +class SubscriptionResponse(BaseModel): + user_id: int + subscription_id: int + application_id: int + email: str + tier: str + application: str + + def require_subscription( application: str, Authorize: AuthJWT = Depends() -) -> SubscriptionBase: +) -> SubscriptionResponse: """ This function is used to check if the user has a valid subscription token. :param application: str @@ -24,43 +33,46 @@ def require_subscription( :return: SubscriptionBase object. """ Authorize.jwt_required() - username = Authorize.get_jwt_subject() - + email = Authorize.get_jwt_subject() claims = Authorize.get_raw_jwt() - subscription_claim = claims.get("subscription") - tier_claim = claims.get("tier") - if subscription_claim != application: + subscription = claims.get("subscription") + tier = claims.get("tier") + if subscription != application: raise HTTPException( HTTP_403_FORBIDDEN, "The API key doesn't have permission to perform the request", ) - return SubscriptionBase( - username=username, tier=tier_claim, application=subscription_claim + user_id = claims.get("use_id", -1) + subscription_id = claims.get("subscription_id", -1) + application_id = claims.get("application_id", -1) + return SubscriptionResponse( + user_id=user_id, subscription_id=subscription_id, + email=email, tier=tier, application=subscription, + application_id=application_id, ) def require_subscription_balance( - subscription: SubscriptionBase = Depends(require_subscription), + subscription: SubscriptionResponse = Depends(require_subscription), redis: Redis = Depends(redis_conn), session=Depends(create_session), -) -> str: +) -> SubscriptionResponse: """ This function is used to check if the user has enough balance to perform. :param subscription: str :param redis: Redis object. :param session: Session object. - :return: username str. + :return: email str. """ - username = subscription.username - tier = subscription.tier - application = subscription.application - - key = make_key(username, application, tier) + key = make_key(subscription) balance = redis.decr(key) + + + print("balance", balance) - if balance == -1: - subscription = SubscriptionQuery(session).get_active_subscription( - username, application + if balance is None or balance == -1: + subscription = SubscriptionQuery(session).get_subscription( + subscription.subscription_id ) balance = subscription.credit - subscription.balance - 1 if balance > 0: @@ -69,7 +81,7 @@ def require_subscription_balance( if balance <= 0: SubscriptionQuery(session).update_balance_in_subscription( - username, application, tier, redis + subscription, redis ) if balance < 0: @@ -78,4 +90,4 @@ def require_subscription_balance( "You have used up all credit for this API", ) - return username + return subscription \ No newline at end of file diff --git a/apihub/subscription/helpers.py b/apihub/subscription/helpers.py index f39ac29..17960c7 100644 --- a/apihub/subscription/helpers.py +++ b/apihub/subscription/helpers.py @@ -6,30 +6,23 @@ BALANCE_KEYS = "balance:keys" -def make_key(username: str, application: str, tier: str) -> str: - """ - Make key for redis. - :param username: str - :param application: str - :param tier: str - :return: str. - """ - return f"balance:{username}:{application}:{tier}" +def make_key(subscription) -> str: + return f"balance:{subscription.user_id}:{subscription.application_id}:{subscription.tier}" @contextmanager def get_and_reset_balance_in_cache( - username: str, application: str, tier: str, redis: Redis + subscription, redis: Redis ) -> None: """ Get balance from cache and delete it. - :param username: str + :param email: str :param application: str :param tier: str :param redis: Redis object. :return: None """ - key = make_key(username, application, tier) + key = make_key(subscription) balance = redis.get(key) yield int(balance) diff --git a/apihub/subscription/models.py b/apihub/subscription/models.py index 95d62c2..0a7b9d3 100644 --- a/apihub/subscription/models.py +++ b/apihub/subscription/models.py @@ -13,7 +13,7 @@ from sqlalchemy.orm import relationship from ..common.db_session import Base -from .schemas import SubscriptionTier, ApplicationCreate, SubscriptionPricingCreate +from .schemas import SubscriptionTier, ApplicationCreate, PricingCreate class Application(Base): @@ -21,18 +21,20 @@ class Application(Base): This class is used to store application data. """ - __tablename__ = "application" + __tablename__ = "applications" id = Column(Integer, primary_key=True, index=True) name = Column(String, unique=True, index=True, nullable=False) url = Column(String) description = Column(String) + is_active = Column(Boolean, default=True) created_at = Column(DateTime, default=datetime.now()) - owner = Column(String, ForeignKey("users.username")) + owner_id = Column(Integer, ForeignKey("users.id")) - subscriptions = relationship("Subscription", backref="app") - subscriptions_pricing = relationship("SubscriptionPricing", backref="app") + owner = relationship("User") + subscriptions = relationship("Subscription", back_populates="application") + pricings = relationship("Pricing", back_populates="application") def __str__(self): return f"{self.name} || {self.url}" @@ -42,32 +44,41 @@ def to_schema(self, with_pricing=False) -> ApplicationCreate: name=self.name, url=self.url, description=self.description, - pricing=[ - SubscriptionPricingCreate( + pricings=[ + PricingCreate( tier=pricing.tier, price=pricing.price, credit=pricing.credit, application=self.name, ) - for pricing in self.subscriptions_pricing + for pricing in self.pricings ] if with_pricing else [], ) -class SubscriptionPricing(Base): +class Pricing(Base): """ This class is used to store subscription pricing data. """ - __tablename__ = "subscription_pricing" - __table_args__ = (UniqueConstraint("application", "tier", name="application_tier"),) + __tablename__ = "pricings" + __table_args__ = ( + UniqueConstraint( + "application_id", "tier", name="application_tier_constraint" + ), + ) id = Column(Integer, primary_key=True, index=True) tier = Column(Enum(SubscriptionTier), default=SubscriptionTier.TRIAL) price = Column(Integer) credit = Column(Integer) + is_active = Column(Boolean, default=True) + + created_at = Column(DateTime, default=datetime.now()) + + application_id = Column(Integer, ForeignKey("applications.id"), nullable=False) + application = relationship("Application", uselist=False, back_populates="pricings") - application = Column(String, ForeignKey("application.name"), nullable=False) def __str__(self): - return f"{self.application} || {self.tier} || {self.price}" + return f"{self.application_id} || {self.tier} || {self.price}" class Subscription(Base): @@ -76,28 +87,33 @@ class Subscription(Base): """ __tablename__ = "subscriptions" - __table_args__ = ( - UniqueConstraint( - "application", "tier", "username", name="application_tier_username" - ), - ) + # __table_args__ = ( + # UniqueConstraint( + # "application_id", "tier", "user_id", name="application_tier_user_constraint" + # ), + # ) id = Column(Integer, primary_key=True, index=True) tier = Column(Enum(SubscriptionTier), default=SubscriptionTier.TRIAL) - active = Column(Boolean, default=True) + is_active = Column(Boolean, default=True) credit = Column(Integer, default=0) balance = Column(Integer, default=0) starts_at = Column(DateTime, default=datetime.now()) expires_at = Column(DateTime) recurring = Column(Boolean, default=False) created_at = Column(DateTime, default=datetime.now()) - created_by = Column(String) + # created_by = Column(Integer, ForeignKey("users.id")) notes = Column(String) - username = Column(String, ForeignKey("users.username"), nullable=False) - user = relationship("User", back_populates="subscriptions") + owner_id = Column(Integer, ForeignKey("users.id")) + owner = relationship("User") + + application_id = Column(Integer, ForeignKey("applications.id"), nullable=False) + application = relationship("Application", uselist=False, back_populates="subscriptions") + + pricing_id = Column(Integer, ForeignKey("pricings.id"), nullable=False) + pricing = relationship("Pricing", uselist=False) - application = Column(String, ForeignKey("application.name"), nullable=False) def __str__(self): - return f"{self.application} || {self.tier} || {self.username}" + return f"{self.application} || {self.tier} || {self.email}" diff --git a/apihub/subscription/queries.py b/apihub/subscription/queries.py index cbfdd4d..b3fe3c7 100644 --- a/apihub/subscription/queries.py +++ b/apihub/subscription/queries.py @@ -8,13 +8,13 @@ from sqlalchemy.orm import Query from ..common.queries import BaseQuery -from .models import Subscription, Application, SubscriptionPricing +from .models import Subscription, Application, Pricing from .schemas import ( SubscriptionCreate, SubscriptionDetails, ApplicationCreate, ApplicationCreateWithOwner, - SubscriptionPricingBase, + PricingBase, ) from .helpers import get_and_reset_balance_in_cache @@ -23,7 +23,7 @@ class ApplicationException(Exception): pass -class SubscriptionPricingException(Exception): +class PricingException(Exception): pass @@ -39,28 +39,26 @@ def get_query(self) -> Query: """ return self.session.query(Application) - def create_application(self, application: ApplicationCreateWithOwner): + def create_application(self, application: ApplicationCreateWithOwner) -> ApplicationCreateWithOwner: """ Create an application. :param application: Application details. :return: Application object. """ - pricing_list = [] - for pricing in application.pricing: - pricing_list.append( - SubscriptionPricing( + pricings = [] + for pricing in application.pricings: + pricings.append( + Pricing( tier=pricing.tier, price=pricing.price, credit=pricing.credit, - application=application.name, ) ) application_object = Application( name=application.name, url=application.url, description=application.description, - owner=application.owner, - subscriptions_pricing=pricing_list, + pricings=pricings, ) try: self.session.add(application_object) @@ -70,7 +68,19 @@ def create_application(self, application: ApplicationCreateWithOwner): self.session.rollback() raise ApplicationException(f"Error creating application: {e}") - def get_application(self, name: str) -> ApplicationCreate: + def get_application(self, application_id: int) -> ApplicationCreate: + """ + Get application by name. + :param name: Application name. + :return: application object. + """ + try: + application = self.get_query().filter(Application.id == application_id).one() + return application.to_schema(with_pricing=True) + except NoResultFound: + raise ApplicationException(f"Application {application_id} not found.") + + def get_application_by_name(self, name: str) -> ApplicationCreate: """ Get application by name. :param name: Application name. @@ -82,7 +92,7 @@ def get_application(self, name: str) -> ApplicationCreate: except NoResultFound: raise ApplicationException(f"Application {name} not found.") - def get_applications(self, username=None) -> List[ApplicationCreate]: + def get_applications(self, email=None) -> List[ApplicationCreate]: """ List applications. :return: List of applications. @@ -91,26 +101,26 @@ def get_applications(self, username=None) -> List[ApplicationCreate]: return list(applications) -class SubscriptionPricingQuery(BaseQuery): +class PricingQuery(BaseQuery): def get_query(self) -> Query: """ Get query object. :return: Query object. """ - return self.session.query(SubscriptionPricing) + return self.session.query(Pricing) def create_subscription_pricing( self, tier: str, application: str, price: int, credit: int - ) -> SubscriptionPricing: + ) -> Pricing: """ Create subscription pricing. :param tier: Subscription tier. :param application: Application name. :param price: Price. :param credit: Credit. - :return: SubscriptionPricing object. + :return: Pricing object. """ - subscription_pricing = SubscriptionPricing( + subscription_pricing = Pricing( tier=tier, application=application, price=price, @@ -122,31 +132,31 @@ def create_subscription_pricing( return subscription_pricing except Exception as e: self.session.rollback() - raise SubscriptionPricingException( + raise PricingException( f"Error creating subscription pricing: {e}" ) def get_subscription_pricing( - self, application: str, tier: str - ) -> SubscriptionPricing: + self, application_id: int, tier: str + ) -> Pricing: """ Get subscription pricing by application name and tier. :param application: Application name. :param tier: Subscription tier. - :return: SubscriptionPricing object. + :return: Pricing object. """ try: return ( self.get_query() .filter( - SubscriptionPricing.application == application, - SubscriptionPricing.tier == tier, + Pricing.application_id == application_id, + Pricing.tier == tier, ) .one() ) except NoResultFound: - raise SubscriptionPricingException( - f"Subscription pricing for application {application} and tier {tier} not found." + raise PricingException( + f"Subscription pricing for application {application_id} and tier {tier} not found." ) @@ -168,7 +178,7 @@ def create_subscription(self, subscription_create: SubscriptionCreate): found_existing_subscription = True try: self.get_active_subscription( - subscription_create.username, subscription_create.application + subscription_create.owner_id, subscription_create.application_id ) except SubscriptionException: found_existing_subscription = False @@ -178,22 +188,21 @@ def create_subscription(self, subscription_create: SubscriptionCreate): "Found existing subscription, please delete it before create new subscription" ) - subscription_pricing = SubscriptionPricingQuery( + subscription_pricing = PricingQuery( self.session ).get_subscription_pricing( - subscription_create.application, subscription_create.tier + subscription_create.application_id, subscription_create.tier ) new_subscription = Subscription( - username=subscription_create.username, - application=subscription_create.application, - active=subscription_create.active, + owner_id=subscription_create.owner_id, + application_id=subscription_create.application_id, + pricing_id=subscription_create.pricing_id, tier=subscription_create.tier, credit=subscription_pricing.credit, balance=subscription_pricing.credit, expires_at=subscription_create.expires_at, recurring=subscription_create.recurring, - created_by=subscription_create.created_by, ) try: self.session.add(new_subscription) @@ -202,12 +211,76 @@ def create_subscription(self, subscription_create: SubscriptionCreate): self.session.rollback() raise SubscriptionException(f"Error creating subscription: {e}") + def get_active_subscription_by_name( + self, owner_id: int, application: str + ) -> SubscriptionDetails: + """ + Get active subscription of a user. + :param email: str + :param application: str + :return: SubscriptionDetails object. + """ + try: + application = self.session.query(Application).filter( + Application.name == application + ).one() + + subscription = self.get_query().filter( + Subscription.owner_id == owner_id, + Subscription.application_id == application.id, + Subscription.is_active == true(), + or_( + Subscription.expires_at.is_(None), + Subscription.expires_at > datetime.now(), + ), + ).one() + except NoResultFound: + raise SubscriptionException("Subscription not found.") + + return SubscriptionDetails( + id=subscription.id, + owner_id=subscription.owner_id, + application_id=subscription.application_id, + pricing_id=subscription.pricing_id, + tier=subscription.tier, + is_active=subscription.is_active, + credit=subscription.credit, + balance=subscription.balance, + starts_at=subscription.starts_at, + expires_at=subscription.expires_at, + recurring=subscription.recurring, + created_at=subscription.created_at, + ) + + def get_subscription(self, subscription_id: int) -> SubscriptionDetails: + try: + subscription = self.get_query().filter( + Subscription.id == subscription_id + ).one() + except NoResultFound: + raise SubscriptionException("Subscription not found.") + + return SubscriptionDetails( + id=subscription.id, + owner_id=subscription.owner_id, + application_id=subscription.application_id, + pricing_id=subscription.pricing_id, + tier=subscription.tier, + is_active=subscription.is_active, + credit=subscription.credit, + balance=subscription.balance, + starts_at=subscription.starts_at, + expires_at=subscription.expires_at, + recurring=subscription.recurring, + created_at=subscription.created_at, + ) + def get_active_subscription( - self, username: str, application: str + self, owner_id: int, application_id: int ) -> SubscriptionDetails: """ Get active subscription of a user. - :param username: str + :param email: str :param application: str :return: SubscriptionDetails object. """ @@ -215,9 +288,9 @@ def get_active_subscription( subscription = ( self.get_query() .filter( - Subscription.username == username, - Subscription.application == application, - Subscription.active == true(), + Subscription.owner_id == owner_id, + Subscription.application_id == application_id, + Subscription.is_active == true(), or_( Subscription.expires_at.is_(None), Subscription.expires_at > datetime.now(), @@ -229,29 +302,30 @@ def get_active_subscription( raise SubscriptionException return SubscriptionDetails( - username=username, - application=application, + id=subscription.id, + owner_id=subscription.owner_id, + application_id=subscription.application_id, + pricing_id=subscription.pricing_id, tier=subscription.tier, - active=subscription.active, + is_active=subscription.is_active, credit=subscription.credit, balance=subscription.balance, starts_at=subscription.starts_at, expires_at=subscription.expires_at, recurring=subscription.recurring, - created_by=subscription.created_by, created_at=subscription.created_at, ) - def get_active_subscriptions(self, username: str) -> List[SubscriptionDetails]: + def get_active_subscriptions(self, user_id: int) -> List[SubscriptionDetails]: """ Get all active subscriptions of a user. - :param username: str + :param email: str :return: list of SubscriptionDetails. """ try: subscriptions = self.get_query().filter( - Subscription.username == username, - Subscription.active == true(), + Subscription.owner_id == user_id, + Subscription.is_active == true(), or_( Subscription.expires_at.is_(None), Subscription.expires_at > datetime.now(), @@ -262,51 +336,35 @@ def get_active_subscriptions(self, username: str) -> List[SubscriptionDetails]: return [ SubscriptionDetails( - username=subscription.username, - application=subscription.application, + id=subscription.id, + owner_id=subscription.owner_id, + application_id=subscription.application_id, + pricing_id=subscription.pricing_id, tier=subscription.tier, - active=subscription.active, + is_active=subscription.is_active, credit=subscription.credit, balance=subscription.balance, expires_at=subscription.expires_at, recurring=subscription.recurring, - created_by=subscription.created_by, created_at=subscription.created_at, ) for subscription in subscriptions ] def update_balance_in_subscription( - self, username: str, application: str, tier: str, redis: Redis + self, subscription: Subscription, redis: Redis ) -> None: """ Update balance in subscription. - :param username: str + :param email: str :param application: str :param tier: str :param redis: Redis object. :return: None """ - try: - subscription = ( - self.get_query() - .filter( - Subscription.username == username, - Subscription.application == application, - Subscription.tier == tier, - Subscription.active == true(), - or_( - Subscription.expires_at.is_(None), - Subscription.expires_at > datetime.now(), - ), - ) - .one() - ) - except NoResultFound: - raise SubscriptionException - with get_and_reset_balance_in_cache( - username, application, tier, redis + # FIXME why not use subscription.id? + subscription.user_id, subscription.application_id, subscription.tier, redis ) as balance: subscription.balance = subscription.credit - balance self.session.add(subscription) diff --git a/apihub/subscription/router.py b/apihub/subscription/router.py index fc4f225..ab53fe9 100644 --- a/apihub/subscription/router.py +++ b/apihub/subscription/router.py @@ -7,9 +7,9 @@ from ..common.db_session import create_session from ..security.schemas import ( - UserBase, + UserBaseWithId, ) -from ..security.depends import require_admin, require_publisher, require_token +from ..security.depends import require_admin, require_publisher, require_token, require_user, require_logged_in from ..security.queries import UserQuery, UserException from .schemas import ( @@ -21,7 +21,7 @@ from .queries import ( SubscriptionQuery, SubscriptionException, - SubscriptionPricingException, + PricingException, ApplicationQuery, ApplicationException, ) @@ -41,14 +41,14 @@ class SubscriptionSettings(BaseSettings): def create_application( application: ApplicationCreate, session: Session = Depends(create_session), - username: str = Depends(require_publisher), + publisher: str = Depends(require_publisher), ): """ Create an application. """ applicationCreateWithOwner = ApplicationCreateWithOwner.copy( application, - update={"owner": username} + update={"owner": publisher} ) try: @@ -60,7 +60,7 @@ def create_application( @router.get("/application", response_model=List[ApplicationCreate]) def get_applications( session: Session = Depends(create_session), - user: UserBase = Depends(require_token), + user: str = Depends(require_logged_in), ): """ List all applications. @@ -77,44 +77,44 @@ def get_applications( def get_application( application: str, session: Session = Depends(create_session), - username: str = Depends(require_admin), + user: str = Depends(require_logged_in), ): try: """ Get an application. """ - return ApplicationQuery(session).get_application(application) + return ApplicationQuery(session).get_application_by_name(application) except ApplicationException: - raise HTTPException(400, "Error while retrieving applications") + raise HTTPException(400, f"Error while retrieving application {application}") @router.post("/subscription") def create_subscription( subscription: SubscriptionIn, - username: str = Depends(require_admin), + admin: str = Depends(require_admin), session=Depends(create_session), ): - # make sure the username exists. + # make sure the email exists. try: - UserQuery(session).get_user_by_username(subscription.username) + UserQuery(session).get_user_by_id(subscription.owner_id) except UserException: - raise HTTPException(401, f"User {subscription.username} not found.") + raise HTTPException(401, f"User {subscription.owner_id} not found.") # make sure the application is not currently active. try: SubscriptionQuery(session).get_active_subscription( - subscription.username, subscription.application + subscription.owner_id, subscription.application_id ) raise HTTPException( - 403, f"Application {subscription.application} already exists." + 403, f"Subscription for applicaiton {subscription.application_id} already exists." ) except SubscriptionException: pass try: - ApplicationQuery(session).get_application(subscription.application) + ApplicationQuery(session).get_application(subscription.application_id) except ApplicationException: - raise HTTPException(404, f"Application {subscription.application} not found.") + raise HTTPException(404, f"Application {subscription.application_id} not found.") if subscription.expires_at is None: subscription.expires_at = datetime.now() + timedelta( @@ -122,13 +122,13 @@ def create_subscription( ) subscription_create = SubscriptionCreate( - username=subscription.username, - application=subscription.application, + owner_id=subscription.owner_id, + application_id=subscription.application_id, + pricing_id=subscription.pricing_id, tier=subscription.tier, starts_at=datetime.now(), expires_at=subscription.expires_at, recurring=subscription.recurring, - created_by=username, ) try: query = SubscriptionQuery(session) @@ -136,19 +136,19 @@ def create_subscription( return subscription_create except SubscriptionException as e: raise HTTPException(400, str(e)) - except SubscriptionPricingException as e: + except PricingException as e: raise HTTPException(400, str(e)) @router.get("/subscription/{application}") def get_active_subscription( - application: str, - user: UserBase = Depends(require_token), + application: int, + user: UserBaseWithId = Depends(require_logged_in), session=Depends(create_session), ): query = SubscriptionQuery(session) try: - subscription = query.get_active_subscription(user.username, application) + subscription = query.get_active_subscription(user.id, application) except SubscriptionException: raise HTTPException(400, "Subscription not found") @@ -157,33 +157,23 @@ def get_active_subscription( @router.get("/subscription") def get_active_subscriptions( - user: UserBase = Depends(require_token), + user: UserBaseWithId = Depends(require_user), session=Depends(create_session), ): if not user.is_user: return [] - username = user.username query = SubscriptionQuery(session) try: - subscriptions = query.get_active_subscriptions(username) + subscriptions = query.get_active_subscriptions(user.id) except SubscriptionException: return [] - return [ - SubscriptionIn( - username=subscription.username, - application=subscription.application, - tier=subscription.tier, - expires_at=subscription.expires_at, - recurring=subscription.recurring, - ) - for subscription in subscriptions - ] + return subscriptions class SubscriptionTokenResponse(BaseModel): - username: str + email: str application: str token: str expires_time: int @@ -192,8 +182,8 @@ class SubscriptionTokenResponse(BaseModel): @router.get("/token/{application}") async def get_application_token( application: str, - user: UserBase = Depends(require_token), - username: Optional[str] = None, + user: UserBaseWithId = Depends(require_user), + email: Optional[str] = None, expires_days: Optional[ int ] = SubscriptionSettings().subscription_token_expires_days, @@ -202,16 +192,16 @@ async def get_application_token( query = SubscriptionQuery(session) if user.is_user: - username = user.username + email = user.email expires_days = SubscriptionSettings().subscription_token_expires_days else: - if username is None: - raise HTTPException(401, "username is missing") + if email is None: + raise HTTPException(401, "email is missing") try: - subscription = query.get_active_subscription(username, application) + subscription = query.get_active_subscription_by_name(user.id, application) except SubscriptionException: - raise HTTPException(401, f"No active subscription found for user {username}") + raise HTTPException(401, f"No active subscription found for user {email}") if subscription.balance > subscription.credit: raise HTTPException(HTTP_429_TOO_MANY_REQUESTS, "You have used up your credit") @@ -224,12 +214,18 @@ async def get_application_token( Authorize = AuthJWT() expires_time = timedelta(days=expires_days) access_token = Authorize.create_access_token( - subject=username, - user_claims={"subscription": application, "tier": subscription.tier}, + subject=email, + user_claims={ + "subscription": application, + "tier": subscription.tier, + "user_id": user.id, + "subscription_id": subscription.id, + "application_id": subscription.application_id, + }, expires_time=expires_time, ) return SubscriptionTokenResponse( - username=username, + email=email, application=application, token=access_token, expires_time=expires_time.seconds, diff --git a/apihub/subscription/schemas.py b/apihub/subscription/schemas.py index d0573a8..1e87a6e 100644 --- a/apihub/subscription/schemas.py +++ b/apihub/subscription/schemas.py @@ -10,17 +10,17 @@ class SubscriptionTier(str, Enum): STANDARD = "STANDARD" PREMIUM = "PREMIUM" -class SubscriptionPricingBase(BaseModel): +class PricingBase(BaseModel): tier: SubscriptionTier price: int credit: int -class SubscriptionPricingCreate(SubscriptionPricingBase): +class PricingCreate(PricingBase): application: str -class SubscriptionPricingDetails(SubscriptionPricingCreate): +class PricingDetails(PricingCreate): id: int @@ -31,20 +31,25 @@ class ApplicationBase(BaseModel): class ApplicationCreate(ApplicationBase): - pricing: List[SubscriptionPricingBase] + pricings: List[PricingBase] class ApplicationCreateWithOwner(ApplicationCreate): - owner: str + owner_id: int + + +class ApplicationDetailsWithId(ApplicationCreateWithOwner): + id: int class SubscriptionBase(BaseModel): - username: str - application: str + owner_id: int + application_id: int tier: SubscriptionTier class SubscriptionIn(SubscriptionBase): + pricing_id: int expires_at: Optional[datetime] = None recurring: bool = False @@ -52,12 +57,10 @@ class SubscriptionIn(SubscriptionBase): class SubscriptionCreate(SubscriptionIn): credit: Optional[int] = 0 starts_at: datetime = datetime.now() - active: bool = True - created_by: str notes: Optional[str] = None - application: str class SubscriptionDetails(SubscriptionCreate): + id: int created_at: datetime - balance: int + balance: int \ No newline at end of file diff --git a/tests/test_activity.py b/tests/test_activity.py deleted file mode 100644 index 27a3bbe..0000000 --- a/tests/test_activity.py +++ /dev/null @@ -1,106 +0,0 @@ -import pytest -from datetime import datetime -import factory -from fastapi import FastAPI -from fastapi.testclient import TestClient - -from apihub.common.db_session import create_session -from apihub.subscription.router import router -from apihub.subscription.models import SubscriptionTier -from apihub.activity.models import Activity -from apihub.activity.queries import ActivityQuery, ActivityException -from apihub.activity.schemas import ActivityCreate, ActivityStatus - - -class ActivityFactory(factory.alchemy.SQLAlchemyModelFactory): - class Meta: - model = Activity - - id = factory.Sequence(int) - created_at = factory.LazyFunction(datetime.now) - request = factory.Sequence(lambda n: f"app{n}") - username = factory.Sequence(lambda n: f"tester{n}") - tier = SubscriptionTier.TRIAL - status = ActivityStatus.PROCESSED - request_key = factory.Sequence(lambda n: f"request_key{n}") - result = "" - payload = "" - ip_address = "" - latency = 0.0 - - -@pytest.fixture(scope="function") -def client(db_session): - def _create_session(): - try: - yield db_session - finally: - pass - - app = FastAPI() - app.include_router(router) - app.dependency_overrides[create_session] = _create_session - ActivityFactory._meta.sqlalchemy_session = db_session - ActivityFactory._meta.sqlalchemy_session_persistence = "commit" - ActivityFactory( - username="tester", - request="async/app1", - request_key="app1_key", - status=ActivityStatus.PROCESSED, - ) - yield TestClient(app) - - -@pytest.fixture(scope="function") -def query(db_session): - yield ActivityQuery(db_session) - - -class TestActivity: - def test_create_activity(self, query): - query.create_activity( - ActivityCreate( - request="async/test", - username="ahmed", - tier=SubscriptionTier.TRIAL, - status=ActivityStatus.ACCEPTED, - request_key="async/test_key1234", - result="", - payload="", - ip_address="", - latency=0.0, - ) - ) - - assert ( - query.get_activity_by_key("async/test_key1234").request_key - == "async/test_key1234" - ) - - def test_get_activity_by_key(self, client, query): - assert query.get_activity_by_key("app1_key").request_key == "app1_key" - - with pytest.raises(ActivityException): - query.get_activity_by_key("key 2") - - def test_update_activity(self, client, query): - activity = query.get_activity_by_key("app1_key") - assert activity.tier == SubscriptionTier.TRIAL - - query.update_activity( - "app1_key", - **{"tier": SubscriptionTier.STANDARD, "ip_address": "test ip"}, - ) - - activity = query.get_activity_by_key("app1_key") - assert ( - activity.tier == SubscriptionTier.STANDARD - and activity.ip_address == "test ip" - and activity.latency > 0.0 - ) - - with pytest.raises(ActivityException): - query.update_activity( - "not existing", - **{"tier": SubscriptionTier.STANDARD, "ip_address": "test ip"}, - ) diff --git a/tests/test_result.py b/tests/test_result.py index 4898fcd..8242459 100644 --- a/tests/test_result.py +++ b/tests/test_result.py @@ -1,30 +1,30 @@ import pytest -from apihub.activity.queries import ActivityQuery -from apihub.activity.schemas import ActivityStatus -from .test_activity import ActivityFactory +# from apihub.activity.queries import ActivityQuery +# from apihub.activity.schemas import ActivityStatus +# from .test_activity import ActivityFactory message_id = "ab7fe542-bdf2-11eb-b401-f21898b454f0" -@pytest.fixture(scope="function") -def query(db_session): - ActivityFactory._meta.sqlalchemy_session = db_session - ActivityFactory._meta.sqlalchemy_session_persistence = "commit" - ActivityFactory( - username="tester", - request="async/app1", - request_key=message_id, - status=ActivityStatus.ACCEPTED, - ) - yield ActivityQuery(db_session) +# @pytest.fixture(scope="function") +# def query(db_session): +# ActivityFactory._meta.sqlalchemy_session = db_session +# ActivityFactory._meta.sqlalchemy_session_persistence = "commit" +# ActivityFactory( +# username="tester", +# request="async/app1", +# request_key=message_id, +# status=ActivityStatus.ACCEPTED, +# ) +# yield ActivityQuery(db_session) class TestResultWriter: - def test_basic(self, db_session, query, monkeypatch): - activity = query.get_activity_by_key(message_id) - assert activity.status == ActivityStatus.ACCEPTED + def test_basic(self, db_session, monkeypatch): + # activity = query.get_activity_by_key(message_id) + # assert activity.status == ActivityStatus.ACCEPTED monkeypatch.setenv("MONITORING", "FALSE") from apihub.result import ResultWriter @@ -39,8 +39,8 @@ def test_basic(self, db_session, query, monkeypatch): except Exception: pytest.fail("worker raised exception") - activity = query.get_activity_by_key(message_id) - assert activity.status == ActivityStatus.PROCESSED + # activity = query.get_activity_by_key(message_id) + # assert activity.status == ActivityStatus.PROCESSED def test_command(self, monkeypatch): monkeypatch.setenv("MONITORING", "FALSE") diff --git a/tests/test_security.py b/tests/test_security.py index 921e8a1..4d4c02a 100644 --- a/tests/test_security.py +++ b/tests/test_security.py @@ -29,7 +29,8 @@ class Meta: model = User id = factory.Sequence(int) - username = factory.Sequence(lambda n: f"tester{n}") + name = factory.Sequence(lambda n: f"Mr. Tester{n}") + email = factory.Sequence(lambda n: f"tester{n}@tester.com") salt = SALT hashed_password = itemgetter(1)(hash_password("password", salt=SALT)) role = UserType.USER @@ -40,18 +41,18 @@ def test_user_create(db_session): query = UserQuery(db_session) query.create_user( user=UserCreate( - username="tester", + name="Mr. Tester", email="newuser@test.com", password="testpassword", role=UserType.USER, ) ) - user = query.get_user_by_username(username="tester") + user = query.get_user_by_email(email="newuser@test.com") assert user is not None another_user = query.get_user_by_id(user_id=user.id) - assert user.username == another_user.username + assert user.email == another_user.email @pytest.fixture(scope="function") @@ -80,48 +81,48 @@ def protected(Authorize: AuthJWT = Depends()): return {"user": user, "role": role} @app.get("/admin") - def admin(username=Depends(require_admin)): - return username + def admin(email=Depends(require_admin)): + return email app.dependency_overrides[create_session] = _create_session UserFactory._meta.sqlalchemy_session = db_session UserFactory._meta.sqlalchemy_session_persistence = "commit" - UserFactory(username="tester", role=UserType.USER) - UserFactory(username="admin", role=UserType.ADMIN) - UserFactory(username="publisher", role=UserType.PUBLISHER) - UserFactory(username="user", role=UserType.USER) - UserFactory(username="app", role=UserType.APP) + UserFactory(email="tester@test.com", role=UserType.USER) + UserFactory(email="admin@test.com", role=UserType.ADMIN) + UserFactory(email="publisher@test.com", role=UserType.PUBLISHER) + UserFactory(email="user@test.com", role=UserType.USER) + UserFactory(email="app@test.com", role=UserType.APP) yield TestClient(app) class TestAuthenticate: - def _make_auth_header(self, username, password): + def _make_auth_header(self, email, password): from base64 import b64encode - raw = b64encode(f"{username}:{password}".encode("ascii")).decode("ascii") + raw = b64encode(f"{email}:{password}".encode("ascii")).decode("ascii") return {"Authorization": f"Basic {raw}"} def test_authenticate_wrong_user(self, client): response = client.get( "/_authenticate", - headers=self._make_auth_header("nosuchuser", "password"), + headers=self._make_auth_header("nosuchuser@test.com", "password"), ) assert response.status_code == 403 def test_authenticate_wrong_password(self, client): response = client.get( "/_authenticate", - headers=self._make_auth_header("tester", "nosuchpassword"), + headers=self._make_auth_header("tester@test.com", "nosuchpassword"), ) assert response.status_code == 403 def test_authenticate(self, client): response = client.get( "/_authenticate", - headers=self._make_auth_header("tester", "password"), + headers=self._make_auth_header("tester@test.com", "password"), params={"expires_days": 2}, ) assert response.status_code == 200 @@ -136,7 +137,7 @@ def test_pretected_no_token(self, client): def test_token(self, client): response = client.get( "/_authenticate", - headers=self._make_auth_header("tester", "password"), + headers=self._make_auth_header("tester@test.com", "password"), ) assert response.status_code == 200 auth_response = AuthenticateResponse.parse_obj(response.json()) @@ -149,7 +150,7 @@ def test_token(self, client): def test_require_admin_when_admin(self, client): response = client.get( "/_authenticate", - headers=self._make_auth_header("admin", "password"), + headers=self._make_auth_header("admin@test.com", "password"), ) assert response.status_code == 200 auth_response = AuthenticateResponse.parse_obj(response.json()) @@ -160,7 +161,7 @@ def test_require_admin_when_admin(self, client): def test_require_admin_when_manager(self, client): response = client.get( "/_authenticate", - headers=self._make_auth_header("publisher", "password"), + headers=self._make_auth_header("publisher@test.com", "password"), ) assert response.status_code == 200 auth_response = AuthenticateResponse.parse_obj(response.json()) @@ -171,24 +172,18 @@ def test_require_admin_when_manager(self, client): def test_create_and_get_user(self, client): response = client.get( "/_authenticate", - headers=self._make_auth_header("admin", "password"), + headers=self._make_auth_header("admin@test.com", "password"), ) assert response.status_code == 200 auth_response = AuthenticateResponse.parse_obj(response.json()) token = auth_response.access_token - response = client.get("/user/me", headers={"Authorization": f"Bearer {token}"}) + response = client.get("/user", headers={"Authorization": f"Bearer {token}"}) assert response.status_code == 200 - assert response.json().get("username") == "admin" - - response = client.get( - "/user/user", headers={"Authorization": f"Bearer {token}"} - ) - assert response.status_code == 200 - assert response.json().get("username") == "user" + assert response.json().get("email") == "admin@test.com" new_user = UserCreate( - username="newuser", + name="New User", email="newuser@test.com", password="password", role="user", @@ -199,17 +194,25 @@ def test_create_and_get_user(self, client): json=new_user.dict(), ) assert response.status_code == 200 + + response = client.get( + "/_authenticate", + headers=self._make_auth_header("newuser@test.com", "password"), + ) + assert response.status_code == 200 + auth_response = AuthenticateResponse.parse_obj(response.json()) + token = auth_response.access_token response = client.get( - f"/user/{new_user.username}", headers={"Authorization": f"Bearer {token}"} + f"/user", headers={"Authorization": f"Bearer {token}"} ) assert response.status_code == 200 - assert response.json().get("username") == new_user.username + assert response.json().get("email") == new_user.email def test_get_users(self, client): response = client.get( "/_authenticate", - headers=self._make_auth_header("admin", "password"), + headers=self._make_auth_header("admin@test.com", "password"), ) assert response.status_code == 200 auth_response = AuthenticateResponse.parse_obj(response.json()) @@ -218,7 +221,7 @@ def test_get_users(self, client): response = client.get( "/user", headers={"Authorization": f"Bearer {token}"}, - json={"usernames": "admin,publisher,user"}, + json={"emails": "admin@test.com,publisher@test.com,user@test.com"}, ) assert response.status_code == 200 assert len(response.json()) == 3 @@ -226,7 +229,7 @@ def test_get_users(self, client): def test_list_users(self, client): response = client.get( "/_authenticate", - headers=self._make_auth_header("admin", "password"), + headers=self._make_auth_header("admin@test.com", "password"), ) assert response.status_code == 200 auth_response = AuthenticateResponse.parse_obj(response.json()) @@ -242,7 +245,7 @@ def test_list_users(self, client): def test_change_password_user(self, client): response = client.get( "/_authenticate", - headers=self._make_auth_header("user", "password"), + headers=self._make_auth_header("user@test.com", "password"), ) assert response.status_code == 200 auth_response = AuthenticateResponse.parse_obj(response.json()) @@ -257,13 +260,13 @@ def test_change_password_user(self, client): response = client.get( "/_authenticate", - headers=self._make_auth_header("user", "password"), + headers=self._make_auth_header("user@test.com", "password"), ) assert response.status_code == 403 response = client.get( "/_authenticate", - headers=self._make_auth_header("user", "newpassword"), + headers=self._make_auth_header("user@test.com", "newpassword"), ) assert response.status_code == 200 @@ -277,7 +280,7 @@ def test_change_password_user(self, client): def test_change_password_admin(self, client): response = client.get( "/_authenticate", - headers=self._make_auth_header("admin", "password"), + headers=self._make_auth_header("admin@test.com", "password"), ) assert response.status_code == 200 auth_response = AuthenticateResponse.parse_obj(response.json()) @@ -285,12 +288,12 @@ def test_change_password_admin(self, client): response = client.get( "/_authenticate", - headers=self._make_auth_header("user", "password"), + headers=self._make_auth_header("user@test.com", "password"), ) assert response.status_code == 200 response = client.post( - "/user/user/_password", + "/user/user@test.com/_password", headers={"Authorization": f"Bearer {token}"}, json={"password": "newpassword"}, ) @@ -298,18 +301,18 @@ def test_change_password_admin(self, client): response = client.get( "/_authenticate", - headers=self._make_auth_header("user", "password"), + headers=self._make_auth_header("user@test.com", "password"), ) assert response.status_code == 403 response = client.get( "/_authenticate", - headers=self._make_auth_header("user", "newpassword"), + headers=self._make_auth_header("user@test.com", "newpassword"), ) assert response.status_code == 200 response = client.post( - "/user/user/_password", + "/user/user@test.com/_password", headers={"Authorization": f"Bearer {token}"}, json={"password": "password"}, ) @@ -318,14 +321,14 @@ def test_change_password_admin(self, client): def test_register(self, client): response = client.get( "/_authenticate", - headers=self._make_auth_header("app", "password"), + headers=self._make_auth_header("app@test.com", "password"), ) assert response.status_code == 200 auth_response = AuthenticateResponse.parse_obj(response.json()) token = auth_response.access_token new_user = UserRegister( - username="newuser", + name="New User", email="newuser@test.com", password="password", ) @@ -338,12 +341,12 @@ def test_register(self, client): response = client.get( "/_authenticate", - headers=self._make_auth_header("newuser", "password"), + headers=self._make_auth_header("newuser@test.com", "password"), ) assert response.status_code == 200 auth_response = AuthenticateResponse.parse_obj(response.json()) token = auth_response.access_token - response = client.get("/user/me", headers={"Authorization": f"Bearer {token}"}) + response = client.get("/user", headers={"Authorization": f"Bearer {token}"}) assert response.status_code == 200 - assert response.json().get("username") == new_user.username + assert response.json().get("email") == new_user.email diff --git a/tests/test_server.py b/tests/test_server.py index 5b715ff..65bda16 100644 --- a/tests/test_server.py +++ b/tests/test_server.py @@ -8,8 +8,8 @@ from openapi_spec_validator import validate_spec, openapi_v30_spec_validator from apihub.common.db_session import create_session -from apihub.subscription.depends import require_subscription -from apihub.subscription.schemas import SubscriptionTier, SubscriptionBase +from apihub.subscription.depends import require_subscription, SubscriptionResponse +from apihub.subscription.schemas import SubscriptionTier from apihub.utils import make_topic @@ -22,8 +22,9 @@ def _ip_rate_limited(): pass def _require_subscription(application:str): - return SubscriptionBase( - username="test", tier=SubscriptionTier.TRIAL, application=application, + return SubscriptionResponse( + user_id=1, subscription_id=1, application_id=1, + email="user@test.com", tier=SubscriptionTier.TRIAL, application="test", ) monkeypatch.setenv("OUT_KIND", "MEM") diff --git a/tests/test_subscription.py b/tests/test_subscription.py index 8fe63ba..71244c6 100644 --- a/tests/test_subscription.py +++ b/tests/test_subscription.py @@ -9,22 +9,23 @@ from apihub.common.db_session import create_session from apihub.security.models import User -from apihub.security.schemas import UserBase, UserType -from apihub.security.depends import require_user, require_admin, require_token, require_publisher +from apihub.security.schemas import UserBase, UserType, UserBaseWithId +from apihub.security.depends import require_user, require_admin, require_token, require_publisher, require_logged_in from apihub.subscription.depends import ( require_subscription_balance, + SubscriptionResponse, ) from apihub.subscription.models import ( Subscription, SubscriptionTier, Application, - SubscriptionPricing, + Pricing, ) from apihub.subscription.router import router from apihub.subscription.schemas import ( SubscriptionIn, ApplicationCreate, - SubscriptionPricingBase, + PricingBase, ) from apihub.security.helpers import hash_password @@ -44,18 +45,18 @@ class Meta: description = "description" created_at = factory.LazyFunction(datetime.now) - owner = "tester" + owner_id = factory.Sequence(int) -class SubscriptionPricingFactory(factory.alchemy.SQLAlchemyModelFactory): +class PricingFactory(factory.alchemy.SQLAlchemyModelFactory): class Meta: - model = SubscriptionPricing + model = Pricing id = factory.Sequence(int) tier = SubscriptionTier.TRIAL price = 100.0 credit = 100.0 - application = factory.Sequence(lambda n: f"app{n}") + application_id = factory.Sequence(int) class UserFactory(factory.alchemy.SQLAlchemyModelFactory): @@ -63,7 +64,8 @@ class Meta: model = User id = factory.Sequence(int) - username = factory.Sequence(lambda n: f"tester{n}") + name = factory.Sequence(lambda n: f"Mr Tester{n}") + email = factory.Sequence(lambda n: f"tester{n}@test.com") salt = SALT hashed_password = itemgetter(1)(hash_password("password", salt=SALT)) role = UserType.USER @@ -75,9 +77,7 @@ class Meta: model = Subscription id = factory.Sequence(int) - username = factory.Sequence(lambda n: f"tester{n}") - application = "test" - active = True + is_active = True tier = SubscriptionTier.TRIAL credit = 100 balance = 0 @@ -85,20 +85,24 @@ class Meta: expires_at = factory.LazyFunction(lambda: datetime.now() + timedelta(days=1)) recurring = False created_at = factory.LazyFunction(datetime.now) - created_by = "admin" + # created_by = "admin" notes = None + owner_id = factory.Sequence(int) + application_id = factory.Sequence(int) + pricing_id = factory.Sequence(int) + def _require_admin_token(): - return UserBase(username="tester", role=UserType.ADMIN) + return UserBaseWithId(id=1, email="tester", name="tester", role=UserType.ADMIN) def _require_user_token(): - return UserBase(username="tester", role=UserType.USER) + return UserBaseWithId(id=1, email="tester", name="tester", role=UserType.USER) def _require_publisher_token(): - return UserBase(username="tester", role=UserType.MANAGER) + return UserBaseWithId(id=1, email="tester", name="tester", role=UserType.PUBLISHER) @pytest.fixture(scope="function") @@ -109,55 +113,47 @@ def _create_session(): finally: pass - def _require_admin(): - return "admin" - - def _require_user(): - return "user" - - def _require_publisher(): - return "publisher" - app = FastAPI() app.include_router(router) app.dependency_overrides[create_session] = _create_session - app.dependency_overrides[require_admin] = _require_admin - app.dependency_overrides[require_user] = _require_user - app.dependency_overrides[require_publisher] = _require_publisher + app.dependency_overrides[require_admin] = _require_admin_token + app.dependency_overrides[require_user] = _require_user_token + app.dependency_overrides[require_publisher] = _require_publisher_token app.dependency_overrides[require_token] = _require_user_token + app.dependency_overrides[require_logged_in] = _require_user_token @app.get("/api_balance/{application}") def api_function_2( - application: str, username: str = Depends(require_subscription_balance) + application: str, subscription: SubscriptionResponse = Depends(require_subscription_balance) ): pass UserFactory._meta.sqlalchemy_session = db_session UserFactory._meta.sqlalchemy_session_persistence = "commit" - UserFactory(username="tester", role=UserType.USER) + tester = UserFactory(id=100, email="tester@test.com", role=UserType.USER) UserFactory._meta.sqlalchemy_session = db_session UserFactory._meta.sqlalchemy_session_persistence = "commit" - UserFactory(username="publisher", role=UserType.PUBLISHER) + publisher = UserFactory(id=200, email="publisher@test.com", role=UserType.PUBLISHER) ApplicationFactory._meta.sqlalchemy_session = db_session ApplicationFactory._meta.sqlalchemy_session_persistence = "commit" - application = ApplicationFactory(name="test", url="/test") + application = ApplicationFactory(id=100, name="test", url="/test", owner_id=publisher.id) - SubscriptionPricingFactory._meta.sqlalchemy_session = db_session - SubscriptionPricingFactory._meta.sqlalchemy_session_persistence = "commit" - pricing = SubscriptionPricingFactory( + PricingFactory._meta.sqlalchemy_session = db_session + PricingFactory._meta.sqlalchemy_session_persistence = "commit" + pricing = PricingFactory( + id=100, tier=SubscriptionTier.TRIAL, price=100, credit=100, - application="test", + application_id=application.id, ) SubscriptionFactory._meta.sqlalchemy_session = db_session SubscriptionFactory._meta.sqlalchemy_session_persistence = "commit" - - SubscriptionFactory(username="tester", application="test", credit=100) + SubscriptionFactory(owner_id=tester.id, application_id=application.id, credit=100, pricing=pricing) yield TestClient(app) @@ -168,14 +164,14 @@ def test_create_application(self, client): name="app", url="/test", description="test", - pricing=[ - SubscriptionPricingBase( + pricings=[ + PricingBase( tier=SubscriptionTier.TRIAL, price=100, credit=100 ), - SubscriptionPricingBase( + PricingBase( tier=SubscriptionTier.STANDARD, price=200, credit=200 ), - SubscriptionPricingBase( + PricingBase( tier=SubscriptionTier.PREMIUM, price=300, credit=300 ), ], @@ -191,7 +187,7 @@ def test_create_application(self, client): ) response_json = response.json() - assert len(response_json["pricing"]) == 3 + assert len(response_json["pricings"]) == 3 def test_list_application(self, client, db_session): response = client.get("/application") @@ -206,30 +202,34 @@ def test_get_application(self, client, db_session): assert response.status_code == 200 response_json = response.json() assert ( - len(response_json["pricing"]) == 1 - and response_json["pricing"][0]["tier"] == "TRIAL" + len(response_json["pricings"]) == 1 + and response_json["pricings"][0]["tier"] == "TRIAL" ) class TestSubscription: def test_create_and_get_subscription(self, client, db_session): + UserFactory._meta.sqlalchemy_session = db_session + UserFactory._meta.sqlalchemy_session_persistence = "commit" + publisher = UserFactory(email="publisher1@test.com", role=UserType.PUBLISHER) + ApplicationFactory._meta.sqlalchemy_session = db_session ApplicationFactory._meta.sqlalchemy_session_persistence = "commit" - ApplicationFactory(name="application", url="/test") + application = ApplicationFactory(name="application", url="/test", owner_id=publisher.id) - SubscriptionPricingFactory._meta.sqlalchemy_session = db_session - SubscriptionPricingFactory._meta.sqlalchemy_session_persistence = "commit" - SubscriptionPricingFactory( + PricingFactory._meta.sqlalchemy_session = db_session + PricingFactory._meta.sqlalchemy_session_persistence = "commit" + pricing = PricingFactory( tier=SubscriptionTier.TRIAL, price=100, credit=100, - application="application", ) # case 1: create subscription new_subscription = SubscriptionIn( - username="tester", - application="application", + owner_id=publisher.id, + application_id=application.id, + pricing_id=pricing.id, tier=SubscriptionTier.TRIAL, expires_at=None, recurring=False, @@ -240,14 +240,24 @@ def test_create_and_get_subscription(self, client, db_session): ) assert response.status_code == 200 + def _require_logged_in(): + return publisher + + client.app.dependency_overrides[require_logged_in] = _require_logged_in + response = client.get( - "/subscription/application", + f"/subscription/{application.id}", ) assert response.status_code == 200 assert response.json().get("credit") == 100 - assert response.json().get("active") is True def test_get_all_subscriptions(self, client): + + def _require_user(): + return UserBaseWithId(id=100, email="", name="", role=UserType.USER) + + client.app.dependency_overrides[require_user] = _require_user + response = client.get( "/subscription", ) @@ -257,8 +267,9 @@ def test_get_all_subscriptions(self, client): def test_create_subscription_not_existing_user(self, client): new_subscription = SubscriptionIn( - username="not existing user", - application="app 1", + owner_id=-1, + application_id=1, + pricing_id=1, tier=SubscriptionTier.TRIAL, expires_at=None, recurring=False, @@ -270,61 +281,31 @@ def test_create_subscription_not_existing_user(self, client): assert response.status_code == 401 def test_get_application_token(self, client, db_session): - SubscriptionFactory._meta.sqlalchemy_session = db_session - SubscriptionFactory._meta.sqlalchemy_session_persistence = "commit" + def _require_user(): + return UserBaseWithId(id=100, email="", name="", role=UserType.USER) - ApplicationFactory(name="app") - SubscriptionPricingFactory( - tier=SubscriptionTier.TRIAL, price=100, credit=100, application="app" - ) - SubscriptionFactory(username="tester", application="app", credit=100) + client.app.dependency_overrides[require_user] = _require_user - response = client.get( - "/token/app", - ) - assert response.status_code == 200, response.json() - assert response.json().get("token") is not None - - ApplicationFactory(name="app_2") - SubscriptionPricingFactory( - tier=SubscriptionTier.TRIAL, price=100, credit=100, application="app_2" - ) - SubscriptionFactory( - username="tester", application="app_2", active=False, credit=1000 - ) - - response = client.get( - "/subscription/app_2", - ) - - assert response.status_code == 400, response.json() - - def test_get_application_token_admin(self, client, db_session): - client.app.dependency_overrides[require_token] = _require_admin_token SubscriptionFactory._meta.sqlalchemy_session = db_session SubscriptionFactory._meta.sqlalchemy_session_persistence = "commit" - ApplicationFactory(name="app2") - SubscriptionPricingFactory( - tier=SubscriptionTier.TRIAL, price=100, credit=100, application="app2" + application = ApplicationFactory(name="app", owner_id=100) + pricing = PricingFactory( + tier=SubscriptionTier.TRIAL, price=100, credit=100, application_id=application.id ) - SubscriptionFactory(username="tester", application="app2", credit=1000) + SubscriptionFactory(owner_id=100, application_id=application.id, pricing_id=pricing.id, credit=100) response = client.get( - "/token/app2", - params={ - "username": "tester", - "expires_days": 30, - }, + "/token/app", ) - client.app.dependency_overrides[require_token] = _require_user_token assert response.status_code == 200, response.json() assert response.json().get("token") is not None def test_create_duplicate_subscription(self, client, db_session): new_subscription = SubscriptionIn( - username="tester", - application="application", + owner_id=100, + application_id=100, + pricing_id=100, tier=SubscriptionTier.TRIAL, expires_at=None, recurring=False, @@ -333,62 +314,24 @@ def test_create_duplicate_subscription(self, client, db_session): "/subscription", data=new_subscription.json(), ) - assert response.status_code == 404 + assert response.status_code == 403 - ApplicationFactory._meta.sqlalchemy_session = db_session - ApplicationFactory._meta.sqlalchemy_session_persistence = "commit" - ApplicationFactory(name="application", url="/test") - - SubscriptionPricingFactory._meta.sqlalchemy_session = db_session - SubscriptionPricingFactory._meta.sqlalchemy_session_persistence = "commit" - SubscriptionPricingFactory( - tier=SubscriptionTier.TRIAL, - price=100, - credit=100, - application="application", - ) - - new_subscription = SubscriptionIn( - username="tester", - application="application", - tier=SubscriptionTier.TRIAL, - expires_at=None, - recurring=False, - ) - response = client.post( - "/subscription", - data=new_subscription.json(), - ) - assert response.status_code == 200, response.json() + def test_require_balance(self, client, db_session): + def _require_user(): + return UserBaseWithId(id=100, email="", name="", role=UserType.USER) - response = client.post( - "/subscription", - data=new_subscription.json(), - ) - assert response.status_code == 403, response.json() + client.app.dependency_overrides[require_user] = _require_user - def test_require_balance(self, client, db_session): SubscriptionFactory._meta.sqlalchemy_session = db_session SubscriptionFactory._meta.sqlalchemy_session_persistence = "commit" - ApplicationFactory(name="app3") - SubscriptionPricingFactory( - tier=SubscriptionTier.TRIAL, price=100, credit=100, application="app3" - ) - SubscriptionFactory( - username="tester", - application="app3", - tier=SubscriptionTier.TRIAL, - credit=2, - ) - response = client.get( - "/token/app3", + "/token/test", ) assert response.status_code == 200, response.json() token = response.json().get("token") response = client.get( - "/api_balance/app3", headers={"Authorization": f"Bearer {token}"} + "/api_balance/test", headers={"Authorization": f"Bearer {token}"} ) - assert response.status_code == 200, response.json() + assert response.status_code == 200, response.json() \ No newline at end of file From a26c3faefb4252c0cdbc2666b46d148e250ff982 Mon Sep 17 00:00:00 2001 From: Yifan Zhang Date: Tue, 24 Jan 2023 16:17:53 +0300 Subject: [PATCH 07/17] rename owner_id to user_id --- apihub/subscription/models.py | 4 ++-- apihub/subscription/queries.py | 22 +++++++++++----------- apihub/subscription/router.py | 8 ++++---- apihub/subscription/schemas.py | 4 ++-- tests/test_subscription.py | 20 ++++++++++---------- 5 files changed, 29 insertions(+), 29 deletions(-) diff --git a/apihub/subscription/models.py b/apihub/subscription/models.py index 0a7b9d3..21a80d1 100644 --- a/apihub/subscription/models.py +++ b/apihub/subscription/models.py @@ -30,7 +30,7 @@ class Application(Base): is_active = Column(Boolean, default=True) created_at = Column(DateTime, default=datetime.now()) - owner_id = Column(Integer, ForeignKey("users.id")) + user_id = Column(Integer, ForeignKey("users.id")) owner = relationship("User") subscriptions = relationship("Subscription", back_populates="application") @@ -105,7 +105,7 @@ class Subscription(Base): # created_by = Column(Integer, ForeignKey("users.id")) notes = Column(String) - owner_id = Column(Integer, ForeignKey("users.id")) + user_id = Column(Integer, ForeignKey("users.id")) owner = relationship("User") application_id = Column(Integer, ForeignKey("applications.id"), nullable=False) diff --git a/apihub/subscription/queries.py b/apihub/subscription/queries.py index b3fe3c7..7656f09 100644 --- a/apihub/subscription/queries.py +++ b/apihub/subscription/queries.py @@ -178,7 +178,7 @@ def create_subscription(self, subscription_create: SubscriptionCreate): found_existing_subscription = True try: self.get_active_subscription( - subscription_create.owner_id, subscription_create.application_id + subscription_create.user_id, subscription_create.application_id ) except SubscriptionException: found_existing_subscription = False @@ -195,7 +195,7 @@ def create_subscription(self, subscription_create: SubscriptionCreate): ) new_subscription = Subscription( - owner_id=subscription_create.owner_id, + user_id=subscription_create.user_id, application_id=subscription_create.application_id, pricing_id=subscription_create.pricing_id, tier=subscription_create.tier, @@ -212,7 +212,7 @@ def create_subscription(self, subscription_create: SubscriptionCreate): raise SubscriptionException(f"Error creating subscription: {e}") def get_active_subscription_by_name( - self, owner_id: int, application: str + self, user_id: int, application: str ) -> SubscriptionDetails: """ Get active subscription of a user. @@ -226,7 +226,7 @@ def get_active_subscription_by_name( ).one() subscription = self.get_query().filter( - Subscription.owner_id == owner_id, + Subscription.user_id == user_id, Subscription.application_id == application.id, Subscription.is_active == true(), or_( @@ -239,7 +239,7 @@ def get_active_subscription_by_name( return SubscriptionDetails( id=subscription.id, - owner_id=subscription.owner_id, + user_id=subscription.user_id, application_id=subscription.application_id, pricing_id=subscription.pricing_id, tier=subscription.tier, @@ -262,7 +262,7 @@ def get_subscription(self, subscription_id: int) -> SubscriptionDetails: return SubscriptionDetails( id=subscription.id, - owner_id=subscription.owner_id, + user_id=subscription.user_id, application_id=subscription.application_id, pricing_id=subscription.pricing_id, tier=subscription.tier, @@ -276,7 +276,7 @@ def get_subscription(self, subscription_id: int) -> SubscriptionDetails: ) def get_active_subscription( - self, owner_id: int, application_id: int + self, user_id: int, application_id: int ) -> SubscriptionDetails: """ Get active subscription of a user. @@ -288,7 +288,7 @@ def get_active_subscription( subscription = ( self.get_query() .filter( - Subscription.owner_id == owner_id, + Subscription.user_id == user_id, Subscription.application_id == application_id, Subscription.is_active == true(), or_( @@ -303,7 +303,7 @@ def get_active_subscription( return SubscriptionDetails( id=subscription.id, - owner_id=subscription.owner_id, + user_id=subscription.user_id, application_id=subscription.application_id, pricing_id=subscription.pricing_id, tier=subscription.tier, @@ -324,7 +324,7 @@ def get_active_subscriptions(self, user_id: int) -> List[SubscriptionDetails]: """ try: subscriptions = self.get_query().filter( - Subscription.owner_id == user_id, + Subscription.user_id == user_id, Subscription.is_active == true(), or_( Subscription.expires_at.is_(None), @@ -337,7 +337,7 @@ def get_active_subscriptions(self, user_id: int) -> List[SubscriptionDetails]: return [ SubscriptionDetails( id=subscription.id, - owner_id=subscription.owner_id, + user_id=subscription.user_id, application_id=subscription.application_id, pricing_id=subscription.pricing_id, tier=subscription.tier, diff --git a/apihub/subscription/router.py b/apihub/subscription/router.py index ab53fe9..2e6276c 100644 --- a/apihub/subscription/router.py +++ b/apihub/subscription/router.py @@ -96,14 +96,14 @@ def create_subscription( ): # make sure the email exists. try: - UserQuery(session).get_user_by_id(subscription.owner_id) + UserQuery(session).get_user_by_id(subscription.user_id) except UserException: - raise HTTPException(401, f"User {subscription.owner_id} not found.") + raise HTTPException(401, f"User {subscription.user_id} not found.") # make sure the application is not currently active. try: SubscriptionQuery(session).get_active_subscription( - subscription.owner_id, subscription.application_id + subscription.user_id, subscription.application_id ) raise HTTPException( 403, f"Subscription for applicaiton {subscription.application_id} already exists." @@ -122,7 +122,7 @@ def create_subscription( ) subscription_create = SubscriptionCreate( - owner_id=subscription.owner_id, + user_id=subscription.user_id, application_id=subscription.application_id, pricing_id=subscription.pricing_id, tier=subscription.tier, diff --git a/apihub/subscription/schemas.py b/apihub/subscription/schemas.py index 1e87a6e..10d6b29 100644 --- a/apihub/subscription/schemas.py +++ b/apihub/subscription/schemas.py @@ -35,7 +35,7 @@ class ApplicationCreate(ApplicationBase): class ApplicationCreateWithOwner(ApplicationCreate): - owner_id: int + user_id: int class ApplicationDetailsWithId(ApplicationCreateWithOwner): @@ -43,7 +43,7 @@ class ApplicationDetailsWithId(ApplicationCreateWithOwner): class SubscriptionBase(BaseModel): - owner_id: int + user_id: int application_id: int tier: SubscriptionTier diff --git a/tests/test_subscription.py b/tests/test_subscription.py index 71244c6..aeaa60f 100644 --- a/tests/test_subscription.py +++ b/tests/test_subscription.py @@ -45,7 +45,7 @@ class Meta: description = "description" created_at = factory.LazyFunction(datetime.now) - owner_id = factory.Sequence(int) + user_id = factory.Sequence(int) class PricingFactory(factory.alchemy.SQLAlchemyModelFactory): @@ -88,7 +88,7 @@ class Meta: # created_by = "admin" notes = None - owner_id = factory.Sequence(int) + user_id = factory.Sequence(int) application_id = factory.Sequence(int) pricing_id = factory.Sequence(int) @@ -139,7 +139,7 @@ def api_function_2( ApplicationFactory._meta.sqlalchemy_session = db_session ApplicationFactory._meta.sqlalchemy_session_persistence = "commit" - application = ApplicationFactory(id=100, name="test", url="/test", owner_id=publisher.id) + application = ApplicationFactory(id=100, name="test", url="/test", user_id=publisher.id) PricingFactory._meta.sqlalchemy_session = db_session PricingFactory._meta.sqlalchemy_session_persistence = "commit" @@ -153,7 +153,7 @@ def api_function_2( SubscriptionFactory._meta.sqlalchemy_session = db_session SubscriptionFactory._meta.sqlalchemy_session_persistence = "commit" - SubscriptionFactory(owner_id=tester.id, application_id=application.id, credit=100, pricing=pricing) + SubscriptionFactory(user_id=tester.id, application_id=application.id, credit=100, pricing=pricing) yield TestClient(app) @@ -215,7 +215,7 @@ def test_create_and_get_subscription(self, client, db_session): ApplicationFactory._meta.sqlalchemy_session = db_session ApplicationFactory._meta.sqlalchemy_session_persistence = "commit" - application = ApplicationFactory(name="application", url="/test", owner_id=publisher.id) + application = ApplicationFactory(name="application", url="/test", user_id=publisher.id) PricingFactory._meta.sqlalchemy_session = db_session PricingFactory._meta.sqlalchemy_session_persistence = "commit" @@ -227,7 +227,7 @@ def test_create_and_get_subscription(self, client, db_session): # case 1: create subscription new_subscription = SubscriptionIn( - owner_id=publisher.id, + user_id=publisher.id, application_id=application.id, pricing_id=pricing.id, tier=SubscriptionTier.TRIAL, @@ -267,7 +267,7 @@ def _require_user(): def test_create_subscription_not_existing_user(self, client): new_subscription = SubscriptionIn( - owner_id=-1, + user_id=-1, application_id=1, pricing_id=1, tier=SubscriptionTier.TRIAL, @@ -289,11 +289,11 @@ def _require_user(): SubscriptionFactory._meta.sqlalchemy_session = db_session SubscriptionFactory._meta.sqlalchemy_session_persistence = "commit" - application = ApplicationFactory(name="app", owner_id=100) + application = ApplicationFactory(name="app", user_id=100) pricing = PricingFactory( tier=SubscriptionTier.TRIAL, price=100, credit=100, application_id=application.id ) - SubscriptionFactory(owner_id=100, application_id=application.id, pricing_id=pricing.id, credit=100) + SubscriptionFactory(user_id=100, application_id=application.id, pricing_id=pricing.id, credit=100) response = client.get( "/token/app", @@ -303,7 +303,7 @@ def _require_user(): def test_create_duplicate_subscription(self, client, db_session): new_subscription = SubscriptionIn( - owner_id=100, + user_id=100, application_id=100, pricing_id=100, tier=SubscriptionTier.TRIAL, From f25217066ce95dbc220c2995174e1eebed91f505 Mon Sep 17 00:00:00 2001 From: Yifan Zhang Date: Wed, 25 Jan 2023 10:04:41 +0300 Subject: [PATCH 08/17] improve token handling --- apihub/security/depends.py | 35 +++++++++++++++------------------- apihub/security/helpers.py | 12 +----------- apihub/security/models.py | 7 +++---- apihub/security/router.py | 25 ++++++++---------------- apihub/security/schemas.py | 39 ++++++++++++++++++++++++++++++++++++++ tests/test_security.py | 28 +++++++++++++-------------- 6 files changed, 80 insertions(+), 66 deletions(-) diff --git a/apihub/security/depends.py b/apihub/security/depends.py index feb4e81..13156ec 100644 --- a/apihub/security/depends.py +++ b/apihub/security/depends.py @@ -4,7 +4,7 @@ from fastapi import HTTPException, Depends, Request from fastapi_jwt_auth import AuthJWT -from .schemas import UserBaseWithId +from .schemas import UserBaseWithId, SecurityToken HTTP_429_TOO_MANY_REQUESTS = 429 @@ -52,17 +52,14 @@ def __init__(self, role: Optional[str] = None, roles: List[str] = list()): def __call__(self, Authorize: AuthJWT = Depends()): Authorize.jwt_required() - claims = Authorize.get_raw_jwt() - role = claims.get("role", "") - if role in self.roles: - name = claims.get("name", "") - email = Authorize.get_jwt_subject() - user_id = claims.get("id", "") + token = SecurityToken.from_token(Authorize) + + if token.role in self.roles: return UserBaseWithId( - id=user_id, - name=name, - email=email, - role=role, + id=token.user_id, + name=token.name, + email=token.email, + role=token.role, ) raise HTTPException( @@ -73,16 +70,14 @@ def __call__(self, Authorize: AuthJWT = Depends()): def require_token(Authorize: AuthJWT = Depends()) -> UserBaseWithId: Authorize.jwt_required() - claims = Authorize.get_raw_jwt() - role = claims.get("role", "") - name = claims.get("name", "") - email = Authorize.get_jwt_subject() - user_id = claims.get("id", "") + + token = SecurityToken.from_token(Authorize) + return UserBaseWithId( - id=user_id, - name=name, - email=email, - role=role, + id=token.user_id, + name=token.name, + email=token.email, + role=token.role, ) diff --git a/apihub/security/helpers.py b/apihub/security/helpers.py index 887e7ed..90b7682 100644 --- a/apihub/security/helpers.py +++ b/apihub/security/helpers.py @@ -24,14 +24,4 @@ def hash_password(password, salt=None): 100000, dklen=64, ).hex() - return salt, hashed_password - - -def make_token(user, expires_time): - Authorize = AuthJWT() - access_token = Authorize.create_access_token( - subject=user.email, - user_claims={"role": user.role, "name": user.name, "id": user.id}, - expires_time=expires_time, - ) - return access_token \ No newline at end of file + return salt, hashed_password \ No newline at end of file diff --git a/apihub/security/models.py b/apihub/security/models.py index e6ead66..47ce2a9 100644 --- a/apihub/security/models.py +++ b/apihub/security/models.py @@ -36,8 +36,6 @@ class User(Base): def __str__(self): return f"{self.email} || {self.role} || {self.is_active}" - - class Profile(Base): """ @@ -46,7 +44,8 @@ class Profile(Base): __tablename__ = "profiles" id = Column(Integer, primary_key=True, index=True) - name = Column(String) + first_name = Column(String) + last_name = Column(String) bio = Column(String) url = Column(String) avatar = Column(String) @@ -55,4 +54,4 @@ class Profile(Base): user = relationship("User", cascade = "all,delete", back_populates="profile") def __str__(self): - return f"{self.user_id} || {self.name}" \ No newline at end of file + return f"{self.user_id} || {self.name}" diff --git a/apihub/security/router.py b/apihub/security/router.py index 481c77c..ac5eb5a 100644 --- a/apihub/security/router.py +++ b/apihub/security/router.py @@ -6,10 +6,9 @@ from fastapi_jwt_auth import AuthJWT from ..common.db_session import create_session -from .schemas import UserCreate, UserBase, UserRegister, UserType +from .schemas import UserCreate, UserBase, UserRegister, UserType, SecurityToken from .queries import UserQuery, UserException from .depends import require_token, require_admin, require_app -from .helpers import make_token security = HTTPBasic() @@ -28,14 +27,7 @@ def get_config(): return SecuritySettings() -class AuthenticateResponse(BaseModel): - email: str - role: str - access_token: str - expires_time: int - - -@router.get("/_authenticate") +@router.get("/_authenticate", response_model=SecurityToken) async def _authenticate( credentials: HTTPBasicCredentials = Depends(security), expires_days: int = 1, @@ -54,17 +46,16 @@ async def _authenticate( if expires_days > SecuritySettings().security_token_expires_time: expires_days = SecuritySettings().security_token_expires_time - expires_time = datetime.timedelta(days=expires_days) - - Authorize = AuthJWT() - access_token = make_token(user, expires_time) - return AuthenticateResponse( + security_token = SecurityToken( email=user.email, role=user.role, - expires_time=expires_time.seconds, - access_token=access_token, + name=user.name, + user_id=user.id, + expires_days=expires_days, ) + return security_token + @router.get("/user") async def get_user( diff --git a/apihub/security/schemas.py b/apihub/security/schemas.py index d68d82a..1de6a10 100644 --- a/apihub/security/schemas.py +++ b/apihub/security/schemas.py @@ -1,10 +1,49 @@ +import datetime from enum import Enum +from typing import Optional from pydantic import BaseModel +from fastapi_jwt_auth import AuthJWT from .helpers import hash_password +class SecurityToken(BaseModel): + email: str + role: str + name: str + user_id: int + expires_days: int + access_token: Optional[str] = None + + def __init__(self, **kwargs): + super().__init__(**kwargs) + if self.access_token is None: + self.access_token = self.to_token() + + def to_token(self): + Authorize = AuthJWT() + expires_time = datetime.timedelta(days=self.expires_days) + access_token = Authorize.create_access_token( + subject=self.email, + user_claims={"role": self.role, "name": self.name, "user_id": self.user_id}, + expires_time=expires_time, + ) + return access_token + + @classmethod + def from_token(cls, Authorize: AuthJWT): + email = Authorize.get_jwt_subject() + claims = Authorize.get_raw_jwt() + return cls( + email=email, + role=claims["role"], + name=claims["name"], + user_id=claims["user_id"], + expires_days=0, + ) + + class UserType(str, Enum): USER = "user" PUBLISHER = "publisher" diff --git a/tests/test_security.py b/tests/test_security.py index 4d4c02a..793a3bf 100644 --- a/tests/test_security.py +++ b/tests/test_security.py @@ -13,8 +13,8 @@ from apihub.common.db_session import create_session from apihub.security.models import User from apihub.security.queries import UserQuery -from apihub.security.schemas import UserCreate, UserType, UserRegister -from apihub.security.router import router, AuthenticateResponse +from apihub.security.schemas import UserCreate, UserType, UserRegister, SecurityToken +from apihub.security.router import router from apihub.security.depends import require_admin from apihub.security.helpers import hash_password @@ -126,7 +126,7 @@ def test_authenticate(self, client): params={"expires_days": 2}, ) assert response.status_code == 200 - assert AuthenticateResponse.parse_obj(response.json()) + assert SecurityToken.parse_obj(response.json()).access_token is not None def test_pretected_no_token(self, client): response = client.get( @@ -140,7 +140,7 @@ def test_token(self, client): headers=self._make_auth_header("tester@test.com", "password"), ) assert response.status_code == 200 - auth_response = AuthenticateResponse.parse_obj(response.json()) + auth_response = SecurityToken.parse_obj(response.json()) token = auth_response.access_token response = client.get( "/protected", headers={"Authorization": f"Bearer {token}"} @@ -153,7 +153,7 @@ def test_require_admin_when_admin(self, client): headers=self._make_auth_header("admin@test.com", "password"), ) assert response.status_code == 200 - auth_response = AuthenticateResponse.parse_obj(response.json()) + auth_response = SecurityToken.parse_obj(response.json()) token = auth_response.access_token response = client.get("/admin", headers={"Authorization": f"Bearer {token}"}) assert response.status_code == 200 @@ -164,7 +164,7 @@ def test_require_admin_when_manager(self, client): headers=self._make_auth_header("publisher@test.com", "password"), ) assert response.status_code == 200 - auth_response = AuthenticateResponse.parse_obj(response.json()) + auth_response = SecurityToken.parse_obj(response.json()) token = auth_response.access_token response = client.get("/admin", headers={"Authorization": f"Bearer {token}"}) assert response.status_code == 403 @@ -175,7 +175,7 @@ def test_create_and_get_user(self, client): headers=self._make_auth_header("admin@test.com", "password"), ) assert response.status_code == 200 - auth_response = AuthenticateResponse.parse_obj(response.json()) + auth_response = SecurityToken.parse_obj(response.json()) token = auth_response.access_token response = client.get("/user", headers={"Authorization": f"Bearer {token}"}) @@ -200,7 +200,7 @@ def test_create_and_get_user(self, client): headers=self._make_auth_header("newuser@test.com", "password"), ) assert response.status_code == 200 - auth_response = AuthenticateResponse.parse_obj(response.json()) + auth_response = SecurityToken.parse_obj(response.json()) token = auth_response.access_token response = client.get( @@ -215,7 +215,7 @@ def test_get_users(self, client): headers=self._make_auth_header("admin@test.com", "password"), ) assert response.status_code == 200 - auth_response = AuthenticateResponse.parse_obj(response.json()) + auth_response = SecurityToken.parse_obj(response.json()) token = auth_response.access_token response = client.get( @@ -232,7 +232,7 @@ def test_list_users(self, client): headers=self._make_auth_header("admin@test.com", "password"), ) assert response.status_code == 200 - auth_response = AuthenticateResponse.parse_obj(response.json()) + auth_response = SecurityToken.parse_obj(response.json()) token = auth_response.access_token response = client.get( @@ -248,7 +248,7 @@ def test_change_password_user(self, client): headers=self._make_auth_header("user@test.com", "password"), ) assert response.status_code == 200 - auth_response = AuthenticateResponse.parse_obj(response.json()) + auth_response = SecurityToken.parse_obj(response.json()) token = auth_response.access_token response = client.post( @@ -283,7 +283,7 @@ def test_change_password_admin(self, client): headers=self._make_auth_header("admin@test.com", "password"), ) assert response.status_code == 200 - auth_response = AuthenticateResponse.parse_obj(response.json()) + auth_response = SecurityToken.parse_obj(response.json()) token = auth_response.access_token response = client.get( @@ -324,7 +324,7 @@ def test_register(self, client): headers=self._make_auth_header("app@test.com", "password"), ) assert response.status_code == 200 - auth_response = AuthenticateResponse.parse_obj(response.json()) + auth_response = SecurityToken.parse_obj(response.json()) token = auth_response.access_token new_user = UserRegister( @@ -344,7 +344,7 @@ def test_register(self, client): headers=self._make_auth_header("newuser@test.com", "password"), ) assert response.status_code == 200 - auth_response = AuthenticateResponse.parse_obj(response.json()) + auth_response = SecurityToken.parse_obj(response.json()) token = auth_response.access_token response = client.get("/user", headers={"Authorization": f"Bearer {token}"}) From 3f0933ee3089891526c9822d631a08c69ade0bd2 Mon Sep 17 00:00:00 2001 From: Yifan Zhang Date: Wed, 25 Jan 2023 10:37:45 +0300 Subject: [PATCH 09/17] improve subscription token handling --- apihub/server.py | 13 ++++------ apihub/subscription/depends.py | 36 +++++++------------------- apihub/subscription/models.py | 2 +- apihub/subscription/router.py | 27 ++++++++------------ apihub/subscription/schemas.py | 46 ++++++++++++++++++++++++++++++++++ tests/test_server.py | 5 ++-- tests/test_subscription.py | 8 +++--- 7 files changed, 78 insertions(+), 59 deletions(-) diff --git a/apihub/server.py b/apihub/server.py index d2afe0c..7209e4b 100644 --- a/apihub/server.py +++ b/apihub/server.py @@ -18,7 +18,7 @@ from .activity.queries import ActivityQuery from .security.depends import RateLimiter, RateLimits, require_user from .security.router import router as security_router -from .subscription.depends import require_subscription, SubscriptionResponse +from .subscription.depends import require_subscription, SubscriptionToken from .subscription.router import router as subscription_router from .utils import ( State, @@ -204,18 +204,15 @@ def fetch_result(email: str, application: str, key: str): async def async_service( request: Request, # background_tasks: BackgroundTasks, - subscription: SubscriptionResponse = Depends(require_subscription), + subscription: SubscriptionToken = Depends(require_subscription), ): """generic handler for async api.""" - email = subscription.email - tier = subscription.tier - application = subscription.application - operation_counter.labels(api=application, user=email, operation="received").inc() + operation_counter.labels(api=subscription.application, user=subscription.email, operation="received").inc() - key = await make_request(email, application, request) + key = await make_request(subscription.email, subscription.application, request) - operation_counter.labels(api=application, user=email, operation="accepted").inc() + operation_counter.labels(api=subscription.application, user=subscription.email, operation="accepted").inc() # activity = ActivityCreate( # request=f"/async/{application}", diff --git a/apihub/subscription/depends.py b/apihub/subscription/depends.py index e604037..630bf10 100644 --- a/apihub/subscription/depends.py +++ b/apihub/subscription/depends.py @@ -6,6 +6,7 @@ from ..common.db_session import create_session from ..common.redis_session import redis_conn +from .schemas import SubscriptionToken from .queries import SubscriptionQuery from .helpers import make_key, BALANCE_KEYS @@ -14,18 +15,9 @@ HTTP_429_QUOTA = 429 -class SubscriptionResponse(BaseModel): - user_id: int - subscription_id: int - application_id: int - email: str - tier: str - application: str - - def require_subscription( application: str, Authorize: AuthJWT = Depends() -) -> SubscriptionResponse: +) -> SubscriptionToken: """ This function is used to check if the user has a valid subscription token. :param application: str @@ -33,30 +25,22 @@ def require_subscription( :return: SubscriptionBase object. """ Authorize.jwt_required() - email = Authorize.get_jwt_subject() - claims = Authorize.get_raw_jwt() - subscription = claims.get("subscription") - tier = claims.get("tier") - if subscription != application: + subscription_token = SubscriptionToken.from_token(Authorize) + + if subscription_token.application != application: raise HTTPException( HTTP_403_FORBIDDEN, "The API key doesn't have permission to perform the request", ) - user_id = claims.get("use_id", -1) - subscription_id = claims.get("subscription_id", -1) - application_id = claims.get("application_id", -1) - return SubscriptionResponse( - user_id=user_id, subscription_id=subscription_id, - email=email, tier=tier, application=subscription, - application_id=application_id, - ) + + return subscription_token def require_subscription_balance( - subscription: SubscriptionResponse = Depends(require_subscription), + subscription: SubscriptionToken = Depends(require_subscription), redis: Redis = Depends(redis_conn), session=Depends(create_session), -) -> SubscriptionResponse: +) -> SubscriptionToken: """ This function is used to check if the user has enough balance to perform. :param subscription: str @@ -68,8 +52,6 @@ def require_subscription_balance( balance = redis.decr(key) - print("balance", balance) - if balance is None or balance == -1: subscription = SubscriptionQuery(session).get_subscription( subscription.subscription_id diff --git a/apihub/subscription/models.py b/apihub/subscription/models.py index 21a80d1..a0ffcd0 100644 --- a/apihub/subscription/models.py +++ b/apihub/subscription/models.py @@ -32,7 +32,7 @@ class Application(Base): created_at = Column(DateTime, default=datetime.now()) user_id = Column(Integer, ForeignKey("users.id")) - owner = relationship("User") + user = relationship("User") subscriptions = relationship("Subscription", back_populates="application") pricings = relationship("Pricing", back_populates="application") diff --git a/apihub/subscription/router.py b/apihub/subscription/router.py index 2e6276c..bb73cda 100644 --- a/apihub/subscription/router.py +++ b/apihub/subscription/router.py @@ -17,6 +17,7 @@ SubscriptionIn, ApplicationCreate, ApplicationCreateWithOwner, + SubscriptionToken, ) from .queries import ( SubscriptionQuery, @@ -179,7 +180,7 @@ class SubscriptionTokenResponse(BaseModel): expires_time: int -@router.get("/token/{application}") +@router.get("/token/{application}", response_model=SubscriptionToken) async def get_application_token( application: str, user: UserBaseWithId = Depends(require_user), @@ -211,22 +212,14 @@ async def get_application_token( if expires_days > subscription_expires_timedelta.days: expires_days = subscription_expires_timedelta.days - Authorize = AuthJWT() - expires_time = timedelta(days=expires_days) - access_token = Authorize.create_access_token( - subject=email, - user_claims={ - "subscription": application, - "tier": subscription.tier, - "user_id": user.id, - "subscription_id": subscription.id, - "application_id": subscription.application_id, - }, - expires_time=expires_time, - ) - return SubscriptionTokenResponse( + subscription_token = SubscriptionToken( email=email, + user_id=user.id, + role=user.role, application=application, - token=access_token, - expires_time=expires_time.seconds, + tier=subscription.tier, + application_id=subscription.application_id, + subscription_id=subscription.id, + expires_days=expires_days, ) + return subscription_token \ No newline at end of file diff --git a/apihub/subscription/schemas.py b/apihub/subscription/schemas.py index 10d6b29..a9ac416 100644 --- a/apihub/subscription/schemas.py +++ b/apihub/subscription/schemas.py @@ -4,6 +4,52 @@ from enum import Enum from pydantic import BaseModel +from fastapi_jwt_auth import AuthJWT + + +class SubscriptionToken(BaseModel): + user_id: int + role: str + subscription_id: int + application_id: int + email: str + tier: str + application: str + access_token: Optional[str] = None + + def __init__(self, **kwargs): + super().__init__(**kwargs) + self.access_token = self.to_token() + + def to_token(self): + Authorize = AuthJWT() + access_token = Authorize.create_access_token( + subject=self.email, + user_claims={ + "role": self.role, + "user_id": self.user_id, + "subscription_id": self.subscription_id, + "application_id": self.application_id, + "tier": self.tier, + "application": self.application, + }, + ) + return access_token + + @classmethod + def from_token(cls, Authorize: AuthJWT): + email = Authorize.get_jwt_subject() + claims = Authorize.get_raw_jwt() + return cls( + email=email, + role=claims["role"], + user_id=claims["user_id"], + subscription_id=claims["subscription_id"], + application_id=claims["application_id"], + tier=claims["tier"], + application=claims["application"], + ) + class SubscriptionTier(str, Enum): TRIAL = "TRIAL" diff --git a/tests/test_server.py b/tests/test_server.py index 65bda16..a13b8fa 100644 --- a/tests/test_server.py +++ b/tests/test_server.py @@ -8,7 +8,7 @@ from openapi_spec_validator import validate_spec, openapi_v30_spec_validator from apihub.common.db_session import create_session -from apihub.subscription.depends import require_subscription, SubscriptionResponse +from apihub.subscription.depends import require_subscription, SubscriptionToken from apihub.subscription.schemas import SubscriptionTier from apihub.utils import make_topic @@ -22,9 +22,10 @@ def _ip_rate_limited(): pass def _require_subscription(application:str): - return SubscriptionResponse( + return SubscriptionToken( user_id=1, subscription_id=1, application_id=1, email="user@test.com", tier=SubscriptionTier.TRIAL, application="test", + access_token="", role="user" ) monkeypatch.setenv("OUT_KIND", "MEM") diff --git a/tests/test_subscription.py b/tests/test_subscription.py index aeaa60f..f634419 100644 --- a/tests/test_subscription.py +++ b/tests/test_subscription.py @@ -13,7 +13,7 @@ from apihub.security.depends import require_user, require_admin, require_token, require_publisher, require_logged_in from apihub.subscription.depends import ( require_subscription_balance, - SubscriptionResponse, + SubscriptionToken, ) from apihub.subscription.models import ( Subscription, @@ -125,7 +125,7 @@ def _create_session(): @app.get("/api_balance/{application}") def api_function_2( - application: str, subscription: SubscriptionResponse = Depends(require_subscription_balance) + application: str, subscription: SubscriptionToken = Depends(require_subscription_balance) ): pass @@ -299,7 +299,7 @@ def _require_user(): "/token/app", ) assert response.status_code == 200, response.json() - assert response.json().get("token") is not None + assert response.json().get("access_token") is not None def test_create_duplicate_subscription(self, client, db_session): new_subscription = SubscriptionIn( @@ -329,7 +329,7 @@ def _require_user(): "/token/test", ) assert response.status_code == 200, response.json() - token = response.json().get("token") + token = response.json().get("access_token") response = client.get( "/api_balance/test", headers={"Authorization": f"Bearer {token}"} From 20945d5b0a116b0f1a35f800f5d52e27efcacd00 Mon Sep 17 00:00:00 2001 From: Yifan Zhang Date: Wed, 25 Jan 2023 11:32:14 +0300 Subject: [PATCH 10/17] cli: Add load_data to admin --- apihub/admin.py | 70 ++++++++++++++++++++++++++++---------- apihub/security/schemas.py | 10 +++++- 2 files changed, 61 insertions(+), 19 deletions(-) diff --git a/apihub/admin.py b/apihub/admin.py index 7bf86b4..8221ea6 100644 --- a/apihub/admin.py +++ b/apihub/admin.py @@ -2,13 +2,14 @@ from pydantic import BaseSettings -from apihub.common.db_session import db_context, Base, DB_ENGINE +from apihub.common.db_session import db_context, Base, DB_ENGINE, create_session from apihub.common.redis_session import redis_conn -from apihub.security.schemas import UserCreate, UserType +from apihub.security.schemas import UserCreate, UserType, ProfileBase from apihub.security.queries import UserQuery from apihub.security.models import User, Profile from apihub.subscription.queries import SubscriptionQuery from apihub.subscription.models import Application, Subscription, Pricing +from apihub.subscription.schemas import ( ApplicationCreate, SubscriptionCreate, ) class SuperUser(BaseSettings): @@ -54,22 +55,55 @@ def deinit(): def load_data(filename): import yaml - data = yaml.load(open(filename, 'r', encoding='utf-8'),) - for name, items in data.items(): - if name == 'user': - with db_context() as session: - query = UserQuery(session) + data = yaml.safe_load(open(filename, 'r', encoding='utf-8'),) + with db_context() as session: + users = {} + applications = {} + pricings = {} + subscriptions = {} + for name, items in data.items(): + if name == 'user': + for item in items: + profile_data = ProfileBase(**item) + name = profile_data.first_name + " " + profile_data.last_name + user_data = UserCreate(name=name, **item) + user = User(**user_data.make_user().dict()) + session.add(user) + + profile = Profile( + user_id=user.id, + **profile_data.dict(), + ) + session.add(profile) + users[user.email] = user + elif name == 'application': for item in items: - user_id = query.create_user( - UserCreate( - name=item.name, - email=item.email, - password=item.password, - role=item.role, + application_data = ApplicationCreate(**item) + pricings = [] + for pricing in item['pricings']: + pricing = Pricing( + tier=pricing['tier'], + credit=pricing['credit'], + price=pricing['price'], ) + pricings.append(pricing) + application = Application( + name=application_data.name, + url=application_data.url, + description=application_data.description, + user_id=users[item['user']].id, + pricings=pricings, ) - if user_id: - elif name == 'application': - for item in items: - else: - print(f"model {name} not supported", file=sys.stderr) + session.add(application) + applications[application.name] = application + # elif name == 'subscription': + # for item in items: + # subscription = Subscription( + # user_id=users[item.user].id, + # application_id=applications[item.application].id, + # pricing_id=pricings[item.pricing].id, + # expires_at=item.expires_at, + # ) + # subscriptions[subscription.user_id] = subscription + else: + print(f"model {name} not supported", file=sys.stderr) diff --git a/apihub/security/schemas.py b/apihub/security/schemas.py index 1de6a10..a4c09a5 100644 --- a/apihub/security/schemas.py +++ b/apihub/security/schemas.py @@ -113,4 +113,12 @@ class UserSession(UserBase): class User(UserSession): - pass \ No newline at end of file + pass + + +class ProfileBase(BaseModel): + first_name: str + last_name: str + bio: Optional[str] + url: Optional[str] + avatar: Optional[str] \ No newline at end of file From f09afdef1c9fe05eed490f5ecb18fa8342f5c708 Mon Sep 17 00:00:00 2001 From: Yifan Zhang Date: Wed, 25 Jan 2023 13:38:31 +0300 Subject: [PATCH 11/17] Update token --- apihub/activity/middlewares.py | 0 apihub/activity/schemas.py | 10 +++++++++- apihub/security/schemas.py | 6 ++++-- apihub/subscription/router.py | 1 + apihub/subscription/schemas.py | 11 ++++++----- tests/test_server.py | 2 +- 6 files changed, 21 insertions(+), 9 deletions(-) create mode 100644 apihub/activity/middlewares.py diff --git a/apihub/activity/middlewares.py b/apihub/activity/middlewares.py new file mode 100644 index 0000000..e69de29 diff --git a/apihub/activity/schemas.py b/apihub/activity/schemas.py index 8199d86..b809c52 100644 --- a/apihub/activity/schemas.py +++ b/apihub/activity/schemas.py @@ -5,7 +5,15 @@ class ActivityBase(BaseModel): - pass + ip: Optional[str] = None + path: str + method: str + user_id: int + application_id: int + subscription_id: int + request_key: Optional[str] = None + payload: Optional[str] = None + response: Optional[str] = None class ActivityCreate(ActivityBase): diff --git a/apihub/security/schemas.py b/apihub/security/schemas.py index a4c09a5..eda1d93 100644 --- a/apihub/security/schemas.py +++ b/apihub/security/schemas.py @@ -13,7 +13,7 @@ class SecurityToken(BaseModel): role: str name: str user_id: int - expires_days: int + expires_days: Optional[int] = None access_token: Optional[str] = None def __init__(self, **kwargs): @@ -26,7 +26,9 @@ def to_token(self): expires_time = datetime.timedelta(days=self.expires_days) access_token = Authorize.create_access_token( subject=self.email, - user_claims={"role": self.role, "name": self.name, "user_id": self.user_id}, + user_claims={ + "role": self.role, "name": self.name, "user_id": self.user_id, + }, expires_time=expires_time, ) return access_token diff --git a/apihub/subscription/router.py b/apihub/subscription/router.py index bb73cda..929089b 100644 --- a/apihub/subscription/router.py +++ b/apihub/subscription/router.py @@ -214,6 +214,7 @@ async def get_application_token( subscription_token = SubscriptionToken( email=email, + name = user.name, user_id=user.id, role=user.role, application=application, diff --git a/apihub/subscription/schemas.py b/apihub/subscription/schemas.py index a9ac416..1f5149d 100644 --- a/apihub/subscription/schemas.py +++ b/apihub/subscription/schemas.py @@ -3,16 +3,14 @@ from enum import Enum from pydantic import BaseModel - from fastapi_jwt_auth import AuthJWT +from ..security.schemas import ( SecurityToken ) -class SubscriptionToken(BaseModel): - user_id: int - role: str + +class SubscriptionToken(SecurityToken): subscription_id: int application_id: int - email: str tier: str application: str access_token: Optional[str] = None @@ -26,12 +24,14 @@ def to_token(self): access_token = Authorize.create_access_token( subject=self.email, user_claims={ + "name": self.name, "role": self.role, "user_id": self.user_id, "subscription_id": self.subscription_id, "application_id": self.application_id, "tier": self.tier, "application": self.application, + "expires_days": self.expires_days, }, ) return access_token @@ -42,6 +42,7 @@ def from_token(cls, Authorize: AuthJWT): claims = Authorize.get_raw_jwt() return cls( email=email, + name=claims["name"], role=claims["role"], user_id=claims["user_id"], subscription_id=claims["subscription_id"], diff --git a/tests/test_server.py b/tests/test_server.py index a13b8fa..21cef3d 100644 --- a/tests/test_server.py +++ b/tests/test_server.py @@ -25,7 +25,7 @@ def _require_subscription(application:str): return SubscriptionToken( user_id=1, subscription_id=1, application_id=1, email="user@test.com", tier=SubscriptionTier.TRIAL, application="test", - access_token="", role="user" + access_token="", role="user", name="user", expires_days=1, ) monkeypatch.setenv("OUT_KIND", "MEM") From 7bb68b50d6da216d2c9f7253b72739dc86f15fa0 Mon Sep 17 00:00:00 2001 From: Yifan Zhang Date: Wed, 25 Jan 2023 20:12:18 +0300 Subject: [PATCH 12/17] Improve server tests --- tests/test_server.py | 25 +++++++++++++++---------- 1 file changed, 15 insertions(+), 10 deletions(-) diff --git a/tests/test_server.py b/tests/test_server.py index 21cef3d..92da18e 100644 --- a/tests/test_server.py +++ b/tests/test_server.py @@ -21,20 +21,12 @@ def _create_session(): def _ip_rate_limited(): pass - def _require_subscription(application:str): - return SubscriptionToken( - user_id=1, subscription_id=1, application_id=1, - email="user@test.com", tier=SubscriptionTier.TRIAL, application="test", - access_token="", role="user", name="user", expires_days=1, - ) monkeypatch.setenv("OUT_KIND", "MEM") from apihub.server import api, ip_rate_limited api.dependency_overrides[ip_rate_limited] = _ip_rate_limited - api.dependency_overrides[require_subscription] = _require_subscription - api.dependency_overrides[create_session] = _create_session yield TestClient(api) @@ -73,9 +65,15 @@ def get(self, application): monkeypatch.setattr( apihub.server, "get_definition_manager", _get_definition_manager ) + token = SubscriptionToken( + user_id=1, subscription_id=1, application_id=1, + email="user@test.com", tier=SubscriptionTier.TRIAL, application="test", + role="user", name="user", expires_days=1, + ) response = client.post( - "/async/test", params={"text": "this is simple"}, json={"probability": 0.6} + "/async/test", params={"text": "this is simple"}, json={"probability": 0.6}, + headers={"Authorization": f"Bearer {token.access_token}"} ) assert response.status_code == 200 @@ -127,8 +125,15 @@ def get(self, application): apihub.server, "get_definition_manager", _get_definition_manager ) + token = SubscriptionToken( + user_id=1, subscription_id=1, application_id=1, + email="user@test.com", tier=SubscriptionTier.TRIAL, application="test", + role="user", name="user", expires_days=1, + ) + response = client.post( - "/async/test", params={}, json={"probability": 0.6} + "/async/test", params={}, json={"probability": 0.6}, + headers={"Authorization": f"Bearer {token.access_token}"} ) assert response.status_code == 422 From 21b54a284d4b51e436944cfffcf36903e28c4e00 Mon Sep 17 00:00:00 2001 From: Yifan Zhang Date: Wed, 25 Jan 2023 20:13:09 +0300 Subject: [PATCH 13/17] Add ActivityLogger middleware --- apihub/activity/middlewares.py | 111 +++++++++++++++++++++++++++++++++ apihub/activity/models.py | 24 +++---- apihub/activity/queries.py | 90 -------------------------- apihub/activity/schemas.py | 26 ++------ apihub/admin.py | 6 +- apihub/server.py | 32 +++------- tests/conftest.py | 3 + 7 files changed, 142 insertions(+), 150 deletions(-) delete mode 100644 apihub/activity/queries.py diff --git a/apihub/activity/middlewares.py b/apihub/activity/middlewares.py index e69de29..2078ab2 100644 --- a/apihub/activity/middlewares.py +++ b/apihub/activity/middlewares.py @@ -0,0 +1,111 @@ +import json +from typing import Callable, Any +from fastapi import Request +from fastapi_jwt_auth import AuthJWT +from starlette.middleware.base import BaseHTTPMiddleware + +from ..common.db_session import db_context +from ..security.schemas import SecurityToken +from .schemas import ActivityBase +from .models import Activity + +class ActivityLogger(BaseHTTPMiddleware): + def __init__(self, app): + super().__init__(app) + + async def set_body(self, request: Request): + receive_ = await request._receive() + + async def receive(): + return receive_ + + request._receive = receive + + async def dispatch(self, request: Request, call_next): + data = { + "ip": request.client.host, + "user_agent": request.headers.get("User-Agent"), + "method": request.method, + "path": request.url.path, + "headers": dict(request.headers), + "query_params": request.query_params, + } + + is_recording = request.url.path.startswith("/async") + + if is_recording: + await self.set_body(request) + body = await request.body() + if body: + data["request_body"] = body + + # get authorization from request + authorization = request.headers.get('Authorization') + if authorization: + auth = AuthJWT(req=request) + token = SecurityToken.from_token(auth) + data["user_id"] = token.user_id + + # call next middleware + response = await call_next(request) + + if is_recording: + # extract response body + data["response_status_code"] = response.status_code + try: + data["response_body"] = json.dumps(await response.json(), encoding='utf-8') + except Exception as e: + pass + + try: + with db_context() as session: + activity = ActivityBase(**data) + session.add(Activity(**activity.dict())) + except Exception as e: + pass + + return response + + +async def log_activity(request: Request, call_next: Callable, session: Any): + data = { + "ip": request.client.host, + "user_agent": request.headers.get("User-Agent"), + "method": request.method, + "path": request.url.path, + "headers": dict(request.headers), + "query_params": request.query_params, + } + + is_recording = request.url.path.startswith("/async") + + if is_recording: + try: + data["request_body"] = json.dumps(await request.json(), encoding='utf-8') + except Exception as e: + pass + + # get authorization from request + authorization = request.headers.get('Authorization') + if authorization: + auth = AuthJWT(req=request) + token = SecurityToken.from_token(auth) + data["user_id"] = token.user_id + + # call next middleware + response = await call_next(request) + + if is_recording: + # extract response body + data["response_status_code"] = response.status_code + try: + data["response_body"] = json.dumps(await response.json(), encoding='utf-8') + except Exception as e: + pass + + # store activity + with db_context() as session: + activity = ActivityBase(**data) + session.add(Activity(**activity.dict())) + + return response \ No newline at end of file diff --git a/apihub/activity/models.py b/apihub/activity/models.py index 54ca666..f07772a 100644 --- a/apihub/activity/models.py +++ b/apihub/activity/models.py @@ -8,23 +8,17 @@ class Activity(Base): - """ - This class is used to store activity data. - """ - - __tablename__ = "activity" + __tablename__ = "activities" id = Column(Integer, primary_key=True, index=True) created_at = Column(DateTime, default=datetime.now()) - request = Column(String) - username = Column(String) - tier = Column(Enum(SubscriptionTier), default=SubscriptionTier.TRIAL) - status = Column(Enum(ActivityStatus), default=ActivityStatus.ACCEPTED) - request_key = Column(String) - result = Column(String) - payload = Column(String) - ip_address = Column(String) - latency = Column(Float) + ip = Column(String) + path = Column(String) + method = Column(String) + user_id = Column(Integer, default=-1) + request_body = Column(String) + response_status_code = Column(String) + response_body = Column(String) def __str__(self): - return f"{self.request} || {self.username}" + return f"{self.ip} || {self.path} || {self.method} || {self.user_id}" \ No newline at end of file diff --git a/apihub/activity/queries.py b/apihub/activity/queries.py deleted file mode 100644 index ff0637a..0000000 --- a/apihub/activity/queries.py +++ /dev/null @@ -1,90 +0,0 @@ -import datetime - -from sqlalchemy.orm import Query -from sqlalchemy.exc import IntegrityError, NoResultFound - -from ..common.queries import BaseQuery - -from .models import Activity -from .schemas import ActivityCreate, ActivityDetails - - -class ActivityException(Exception): - pass - - -class ActivityQuery(BaseQuery): - def get_query(self) -> Query: - """ - Get query object - :return: Query object. - """ - return self.session.query(Activity) - - def create_activity(self, activity: ActivityCreate) -> None: - """ - Create a new activity. - :param activity_create: ActivityCreate object. - :return: None - """ - activity = Activity(**activity.dict()) - self.session.add(activity) - try: - self.session.commit() - except IntegrityError: - self.session.rollback() - raise ActivityException("IntegrityError") - - def get_activity_by_key(self, request_key: str) -> ActivityDetails: - """ - Get activity by request key. - :param request_key: str - :return: ActivityDetails object. - """ - try: - activity = ( - self.get_query().filter(Activity.request_key == request_key).one() - ) - return ActivityDetails( - created_at=activity.created_at, - request=activity.request, - tier=activity.tier, - status=activity.status, - request_key=activity.request_key, - result=activity.result, - payload=activity.payload, - ip_address=activity.ip_address, - latency=activity.latency, - ) - except NoResultFound: - raise ActivityException - - def update_activity(self, request_key, set_latency=True, **kwargs) -> None: - """ - Update activity by request key. - :param request_key: str - :param set_latency: bool - :param kwargs: dict of fields to update. - :return: None - """ - try: - activity = ( - self.get_query().filter(Activity.request_key == request_key).one() - ) - if set_latency: - kwargs["latency"] = ( - datetime.datetime.now() - activity.created_at - ).total_seconds() - except NoResultFound: - raise ActivityException - - for key, value in kwargs.items(): - setattr(activity, key, value) - - self.session.add(activity) - try: - self.session.commit() - self.session.refresh(activity) - except IntegrityError: - self.session.rollback() - raise ActivityException("IntegrityError") diff --git a/apihub/activity/schemas.py b/apihub/activity/schemas.py index b809c52..9100c2b 100644 --- a/apihub/activity/schemas.py +++ b/apihub/activity/schemas.py @@ -8,30 +8,16 @@ class ActivityBase(BaseModel): ip: Optional[str] = None path: str method: str - user_id: int - application_id: int - subscription_id: int - request_key: Optional[str] = None - payload: Optional[str] = None - response: Optional[str] = None + user_id: int = -1 + request_body: Optional[str] = None + response_status_code: Optional[str] = None + response_body: Optional[str] = None -class ActivityCreate(ActivityBase): - request: str - username: Optional[str] = None - tier: str - status: str - request_key: Optional[str] = None - result: Optional[str] = None - payload: Optional[str] = None - ip_address: Optional[str] = None - latency: Optional[float] = None - - -class ActivityDetails(ActivityCreate): +class ActivityDetails(ActivityBase): created_at: datetime class ActivityStatus(str, Enum): ACCEPTED = "ACCEPTED" - PROCESSED = "PROCESSED" + PROCESSED = "PROCESSED" \ No newline at end of file diff --git a/apihub/admin.py b/apihub/admin.py index 8221ea6..542d36a 100644 --- a/apihub/admin.py +++ b/apihub/admin.py @@ -6,10 +6,11 @@ from apihub.common.redis_session import redis_conn from apihub.security.schemas import UserCreate, UserType, ProfileBase from apihub.security.queries import UserQuery -from apihub.security.models import User, Profile +from apihub.security.models import * from apihub.subscription.queries import SubscriptionQuery -from apihub.subscription.models import Application, Subscription, Pricing +from apihub.subscription.models import * from apihub.subscription.schemas import ( ApplicationCreate, SubscriptionCreate, ) +from apihub.activity.models import * class SuperUser(BaseSettings): @@ -29,6 +30,7 @@ def as_usercreate(self): def init(): Base.metadata.bind = DB_ENGINE Base.metadata.create_all() + print("\n".join(Base.metadata.tables.keys()), file=sys.stderr) with db_context() as session: user = SuperUser().as_usercreate() diff --git a/apihub/server.py b/apihub/server.py index 7209e4b..716b4ce 100644 --- a/apihub/server.py +++ b/apihub/server.py @@ -1,9 +1,11 @@ import sys import functools +from functools import partial import logging -from typing import Dict, Any +from typing import Coroutine, Dict, Any -from fastapi import FastAPI, HTTPException, Request, Query, Depends, BackgroundTasks +from fastapi import FastAPI, HTTPException, Request, Query, Depends +from fastapi.responses import JSONResponse from pydantic import BaseModel, Field from fastapi_jwt_auth import AuthJWT from fastapi_jwt_auth.exceptions import AuthJWTException @@ -13,9 +15,9 @@ from dotenv import load_dotenv from pipeline import Message, Settings, Command, CommandActions, Monitor -from .common.db_session import db_context -from .activity.schemas import ActivityStatus, ActivityCreate -from .activity.queries import ActivityQuery +from .common.db_session import create_session +from .activity.schemas import ActivityStatus +from .activity.middlewares import ActivityLogger from .security.depends import RateLimiter, RateLimits, require_user from .security.router import router as security_router from .subscription.depends import require_subscription, SubscriptionToken @@ -85,6 +87,8 @@ def jwt_get_config(): subscription_router, tags=["subscription"], dependencies=[Depends(ip_rate_limited)] ) +api.add_middleware(ActivityLogger) + @api.exception_handler(AuthJWTException) def authjwt_exception_handler(request: Request, exc: AuthJWTException): @@ -203,7 +207,6 @@ def fetch_result(email: str, application: str, key: str): ) async def async_service( request: Request, - # background_tasks: BackgroundTasks, subscription: SubscriptionToken = Depends(require_subscription), ): """generic handler for async api.""" @@ -214,23 +217,6 @@ async def async_service( operation_counter.labels(api=subscription.application, user=subscription.email, operation="accepted").inc() - # activity = ActivityCreate( - # request=f"/async/{application}", - # username=username, - # tier=tier, - # status=ActivityStatus.ACCEPTED, - # request_key=str(key), - # result=str(info.dict()), - # payload=str(dct), - # ip_address=str(request.client.host), - # latency=0.0, - # ) - # - # def add_activity_task(activity): - # with db_context() as session: - # ActivityQuery(session).create_activity(activity) - # - # background_tasks.add_task(add_activity_task, activity=activity) return AsyncAPIRequestResponse(success=True, key=key) diff --git a/tests/conftest.py b/tests/conftest.py index 609b637..5bc5336 100644 --- a/tests/conftest.py +++ b/tests/conftest.py @@ -6,6 +6,9 @@ create_database, drop_database, ) +from apihub.security.models import * +from apihub.subscription.models import * +from apihub.activity.models import * DB_ENGINE = get_db_engine() From 36a66f429148269d9095ab8fd55714f1258e12a2 Mon Sep 17 00:00:00 2001 From: Yifan Zhang Date: Wed, 25 Jan 2023 20:14:12 +0300 Subject: [PATCH 14/17] Add pyyaml for admin script --- pyproject.toml | 1 + 1 file changed, 1 insertion(+) diff --git a/pyproject.toml b/pyproject.toml index 446a60b..2874670 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -16,6 +16,7 @@ apihub_admin = "apihub.admin:create_all_statements" [tool.poetry.group.dev.dependencies] openapi-spec-validator = "^0.5.1" alembic = "^1.9.1" +pyyaml = "^6.0" [tool.black] line-length = 88 From ddab341228f6d74b92f8939b95811618701f862b Mon Sep 17 00:00:00 2001 From: Yifan Zhang Date: Wed, 25 Jan 2023 20:14:58 +0300 Subject: [PATCH 15/17] Update lock --- poetry.lock | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/poetry.lock b/poetry.lock index 62c8659..cd1bedb 100644 --- a/poetry.lock +++ b/poetry.lock @@ -2462,4 +2462,4 @@ cffi = ["cffi (>=1.11)"] [metadata] lock-version = "2.0" python-versions = "^3.7" -content-hash = "f217b8aa99c8e57d4f8a77613ff54337e5f60037383521536d00892da025368e" +content-hash = "99aa50f05c3b71e5387e8b5084f38f2ca49305473c8726b5289f79b3729e95c3" From 88160c9f706edfab7cb25d507c47abe578038344 Mon Sep 17 00:00:00 2001 From: Yifan Zhang Date: Wed, 25 Jan 2023 21:26:31 +0300 Subject: [PATCH 16/17] git: ignore prod --- .gitignore | 2 ++ 1 file changed, 2 insertions(+) diff --git a/.gitignore b/.gitignore index 03e00e0..71a2baa 100644 --- a/.gitignore +++ b/.gitignore @@ -10,3 +10,5 @@ dist **.apihub **/.DS_Store .secrets +prod.env +prod From 1b5f4a394c29420181aa4df2d6f30285b75a21a6 Mon Sep 17 00:00:00 2001 From: Yifan Zhang Date: Thu, 26 Jan 2023 14:09:57 +0300 Subject: [PATCH 17/17] Add path to Application model --- apihub/subscription/models.py | 5 +++++ apihub/subscription/queries.py | 12 ++++++------ apihub/subscription/router.py | 16 ++++++++-------- 3 files changed, 19 insertions(+), 14 deletions(-) diff --git a/apihub/subscription/models.py b/apihub/subscription/models.py index a0ffcd0..1c50b66 100644 --- a/apihub/subscription/models.py +++ b/apihub/subscription/models.py @@ -16,6 +16,10 @@ from .schemas import SubscriptionTier, ApplicationCreate, PricingCreate +def set_default_path(context): + return context.get_current_parameters()["name"].lower().replace(" ", "-") + + class Application(Base): """ This class is used to store application data. @@ -25,6 +29,7 @@ class Application(Base): id = Column(Integer, primary_key=True, index=True) name = Column(String, unique=True, index=True, nullable=False) + path = Column(String, unique=True, index=True, nullable=False, default=set_default_path) url = Column(String) description = Column(String) is_active = Column(Boolean, default=True) diff --git a/apihub/subscription/queries.py b/apihub/subscription/queries.py index 7656f09..fbd87fb 100644 --- a/apihub/subscription/queries.py +++ b/apihub/subscription/queries.py @@ -80,17 +80,17 @@ def get_application(self, application_id: int) -> ApplicationCreate: except NoResultFound: raise ApplicationException(f"Application {application_id} not found.") - def get_application_by_name(self, name: str) -> ApplicationCreate: + def get_application_by_path(self, path: str) -> ApplicationCreate: """ Get application by name. :param name: Application name. :return: application object. """ try: - application = self.get_query().filter(Application.name == name).one() + application = self.get_query().filter(Application.path == path).one() return application.to_schema(with_pricing=True) except NoResultFound: - raise ApplicationException(f"Application {name} not found.") + raise ApplicationException(f"Application with path {path} not found.") def get_applications(self, email=None) -> List[ApplicationCreate]: """ @@ -211,8 +211,8 @@ def create_subscription(self, subscription_create: SubscriptionCreate): self.session.rollback() raise SubscriptionException(f"Error creating subscription: {e}") - def get_active_subscription_by_name( - self, user_id: int, application: str + def get_active_subscription_by_path( + self, user_id: int, path: str ) -> SubscriptionDetails: """ Get active subscription of a user. @@ -222,7 +222,7 @@ def get_active_subscription_by_name( """ try: application = self.session.query(Application).filter( - Application.name == application + Application.path == path ).one() subscription = self.get_query().filter( diff --git a/apihub/subscription/router.py b/apihub/subscription/router.py index 929089b..e66ef04 100644 --- a/apihub/subscription/router.py +++ b/apihub/subscription/router.py @@ -74,9 +74,9 @@ def get_applications( raise HTTPException(400, detail=str(e)) -@router.get("/application/{application}", response_model=ApplicationCreate) +@router.get("/application/{path}", response_model=ApplicationCreate) def get_application( - application: str, + path: str, session: Session = Depends(create_session), user: str = Depends(require_logged_in), ): @@ -84,9 +84,9 @@ def get_application( """ Get an application. """ - return ApplicationQuery(session).get_application_by_name(application) + return ApplicationQuery(session).get_application_by_path(path) except ApplicationException: - raise HTTPException(400, f"Error while retrieving application {application}") + raise HTTPException(400, f"Error while retrieving application with path {path}") @router.post("/subscription") @@ -180,9 +180,9 @@ class SubscriptionTokenResponse(BaseModel): expires_time: int -@router.get("/token/{application}", response_model=SubscriptionToken) +@router.get("/token/{path}", response_model=SubscriptionToken) async def get_application_token( - application: str, + path: str, user: UserBaseWithId = Depends(require_user), email: Optional[str] = None, expires_days: Optional[ @@ -200,7 +200,7 @@ async def get_application_token( raise HTTPException(401, "email is missing") try: - subscription = query.get_active_subscription_by_name(user.id, application) + subscription = query.get_active_subscription_by_path(user.id, path) except SubscriptionException: raise HTTPException(401, f"No active subscription found for user {email}") @@ -217,7 +217,7 @@ async def get_application_token( name = user.name, user_id=user.id, role=user.role, - application=application, + application=path, tier=subscription.tier, application_id=subscription.application_id, subscription_id=subscription.id,