diff --git a/.gitignore b/.gitignore index 03e00e0..71a2baa 100644 --- a/.gitignore +++ b/.gitignore @@ -10,3 +10,5 @@ dist **.apihub **/.DS_Store .secrets +prod.env +prod diff --git a/9be9edae04c5_add_owner_and_created_at_to_application_.py b/9be9edae04c5_add_owner_and_created_at_to_application_.py new file mode 100644 index 0000000..c859341 --- /dev/null +++ b/9be9edae04c5_add_owner_and_created_at_to_application_.py @@ -0,0 +1,26 @@ +"""add owner and created_at to application table + +Revision ID: 9be9edae04c5 +Revises: +Create Date: 2023-01-12 15:01:46.535112 + +""" +from alembic import op +import sqlalchemy as sa + + +# revision identifiers, used by Alembic. +revision = '9be9edae04c5' +down_revision = None +branch_labels = None +depends_on = None + + +def upgrade() -> None: + op.add_column('application', sa.Column("created_at", sa.DateTime, nullable=False)) + op.add_column('application', sa.Column("owner", sa.ForeignKey('user.username'), nullable=False)) + + +def downgrade() -> None: + op.drop_column('application', "owner") + op.drop_column('application', "created_at") \ No newline at end of file diff --git a/alembic.ini b/alembic.ini new file mode 100644 index 0000000..86d1800 --- /dev/null +++ b/alembic.ini @@ -0,0 +1,105 @@ +# A generic, single database configuration. + +[alembic] +# path to migration scripts +script_location = alembic + +# template used to generate migration file names; The default value is %%(rev)s_%%(slug)s +# Uncomment the line below if you want the files to be prepended with date and time +# see https://alembic.sqlalchemy.org/en/latest/tutorial.html#editing-the-ini-file +# for all available tokens +# file_template = %%(year)d_%%(month).2d_%%(day).2d_%%(hour).2d%%(minute).2d-%%(rev)s_%%(slug)s + +# sys.path path, will be prepended to sys.path if present. +# defaults to the current working directory. +prepend_sys_path = . + +# timezone to use when rendering the date within the migration file +# as well as the filename. +# If specified, requires the python-dateutil library that can be +# installed by adding `alembic[tz]` to the pip requirements +# string value is passed to dateutil.tz.gettz() +# leave blank for localtime +# timezone = + +# max length of characters to apply to the +# "slug" field +# truncate_slug_length = 40 + +# set to 'true' to run the environment during +# the 'revision' command, regardless of autogenerate +# revision_environment = false + +# set to 'true' to allow .pyc and .pyo files without +# a source .py file to be detected as revisions in the +# versions/ directory +# sourceless = false + +# version location specification; This defaults +# to alembic/versions. When using multiple version +# directories, initial revisions must be specified with --version-path. +# The path separator used here should be the separator specified by "version_path_separator" below. +# version_locations = %(here)s/bar:%(here)s/bat:alembic/versions + +# version path separator; As mentioned above, this is the character used to split +# version_locations. The default within new alembic.ini files is "os", which uses os.pathsep. +# If this key is omitted entirely, it falls back to the legacy behavior of splitting on spaces and/or commas. +# Valid values for version_path_separator are: +# +# version_path_separator = : +# version_path_separator = ; +# version_path_separator = space +version_path_separator = os # Use os.pathsep. Default configuration used for new projects. + +# the output encoding used when revision files +# are written from script.py.mako +# output_encoding = utf-8 + +sqlalchemy.url = + + +[post_write_hooks] +# post_write_hooks defines scripts or Python functions that are run +# on newly generated revision scripts. See the documentation for further +# detail and examples + +# format using "black" - use the console_scripts runner, against the "black" entrypoint +# hooks = black +# black.type = console_scripts +# black.entrypoint = black +# black.options = -l 79 REVISION_SCRIPT_FILENAME + +# Logging configuration +[loggers] +keys = root,sqlalchemy,alembic + +[handlers] +keys = console + +[formatters] +keys = generic + +[logger_root] +level = WARN +handlers = console +qualname = + +[logger_sqlalchemy] +level = WARN +handlers = +qualname = sqlalchemy.engine + +[logger_alembic] +level = INFO +handlers = +qualname = alembic + +[handler_console] +class = StreamHandler +args = (sys.stderr,) +level = NOTSET +formatter = generic + +[formatter_generic] +format = %(levelname)-5.5s [%(name)s] %(message)s +datefmt = %H:%M:%S diff --git a/alembic/README b/alembic/README new file mode 100644 index 0000000..98e4f9c --- /dev/null +++ b/alembic/README @@ -0,0 +1 @@ +Generic single-database configuration. \ No newline at end of file diff --git a/alembic/env.py b/alembic/env.py new file mode 100644 index 0000000..f96cf10 --- /dev/null +++ b/alembic/env.py @@ -0,0 +1,81 @@ +import os +from logging.config import fileConfig + +from sqlalchemy import engine_from_config +from sqlalchemy import pool + +from alembic import context + +# this is the Alembic Config object, which provides +# access to the values within the .ini file in use. +config = context.config + +# Interpret the config file for Python logging. +# This line sets up loggers basically. +if config.config_file_name is not None: + fileConfig(config.config_file_name) + +config.set_main_option('sqlalchemy.url', os.environ.get('DB_URI').replace('%', '%%')) + +# add your model's MetaData object here +# for 'autogenerate' support +import apihub.server +from apihub.common.db_session import Base +target_metadata = Base.metadata + +# other values from the config, defined by the needs of env.py, +# can be acquired: +# my_important_option = config.get_main_option("my_important_option") +# ... etc. + + +def run_migrations_offline() -> None: + """Run migrations in 'offline' mode. + + This configures the context with just a URL + and not an Engine, though an Engine is acceptable + here as well. By skipping the Engine creation + we don't even need a DBAPI to be available. + + Calls to context.execute() here emit the given string to the + script output. + + """ + url = config.get_main_option("sqlalchemy.url") + context.configure( + url=url, + target_metadata=target_metadata, + literal_binds=True, + dialect_opts={"paramstyle": "named"}, + ) + + with context.begin_transaction(): + context.run_migrations() + + +def run_migrations_online() -> None: + """Run migrations in 'online' mode. + + In this scenario we need to create an Engine + and associate a connection with the context. + + """ + connectable = engine_from_config( + config.get_section(config.config_ini_section), + prefix="sqlalchemy.", + poolclass=pool.NullPool, + ) + + with connectable.connect() as connection: + context.configure( + connection=connection, target_metadata=target_metadata + ) + + with context.begin_transaction(): + context.run_migrations() + + +if context.is_offline_mode(): + run_migrations_offline() +else: + run_migrations_online() diff --git a/alembic/script.py.mako b/alembic/script.py.mako new file mode 100644 index 0000000..55df286 --- /dev/null +++ b/alembic/script.py.mako @@ -0,0 +1,24 @@ +"""${message} + +Revision ID: ${up_revision} +Revises: ${down_revision | comma,n} +Create Date: ${create_date} + +""" +from alembic import op +import sqlalchemy as sa +${imports if imports else ""} + +# revision identifiers, used by Alembic. +revision = ${repr(up_revision)} +down_revision = ${repr(down_revision)} +branch_labels = ${repr(branch_labels)} +depends_on = ${repr(depends_on)} + + +def upgrade() -> None: + ${upgrades if upgrades else "pass"} + + +def downgrade() -> None: + ${downgrades if downgrades else "pass"} diff --git a/alembic/versions/5012a4422d71_add_owner_to_application_table.py b/alembic/versions/5012a4422d71_add_owner_to_application_table.py new file mode 100644 index 0000000..77c6fc5 --- /dev/null +++ b/alembic/versions/5012a4422d71_add_owner_to_application_table.py @@ -0,0 +1,27 @@ +"""Add owner to application table + +Revision ID: 5012a4422d71 +Revises: +Create Date: 2023-01-15 10:55:32.594417 + +""" +from alembic import op +import sqlalchemy as sa + + +# revision identifiers, used by Alembic. +revision = '5012a4422d71' +down_revision = None +branch_labels = None +depends_on = None + + +def upgrade() -> None: + op.add_column('application', sa.Column('created_at', sa.DateTime(), nullable=True)) + op.add_column('application', sa.Column('owner', sa.String(), nullable=True)) + op.create_foreign_key(None, 'application', 'users', ['owner'], ['username']) + +def downgrade() -> None: + op.drop_constraint(None, 'application', type_='foreignkey') + op.drop_column('application', 'owner') + op.drop_column('application', 'created_at') diff --git a/apihub/activity/middlewares.py b/apihub/activity/middlewares.py new file mode 100644 index 0000000..2078ab2 --- /dev/null +++ b/apihub/activity/middlewares.py @@ -0,0 +1,111 @@ +import json +from typing import Callable, Any +from fastapi import Request +from fastapi_jwt_auth import AuthJWT +from starlette.middleware.base import BaseHTTPMiddleware + +from ..common.db_session import db_context +from ..security.schemas import SecurityToken +from .schemas import ActivityBase +from .models import Activity + +class ActivityLogger(BaseHTTPMiddleware): + def __init__(self, app): + super().__init__(app) + + async def set_body(self, request: Request): + receive_ = await request._receive() + + async def receive(): + return receive_ + + request._receive = receive + + async def dispatch(self, request: Request, call_next): + data = { + "ip": request.client.host, + "user_agent": request.headers.get("User-Agent"), + "method": request.method, + "path": request.url.path, + "headers": dict(request.headers), + "query_params": request.query_params, + } + + is_recording = request.url.path.startswith("/async") + + if is_recording: + await self.set_body(request) + body = await request.body() + if body: + data["request_body"] = body + + # get authorization from request + authorization = request.headers.get('Authorization') + if authorization: + auth = AuthJWT(req=request) + token = SecurityToken.from_token(auth) + data["user_id"] = token.user_id + + # call next middleware + response = await call_next(request) + + if is_recording: + # extract response body + data["response_status_code"] = response.status_code + try: + data["response_body"] = json.dumps(await response.json(), encoding='utf-8') + except Exception as e: + pass + + try: + with db_context() as session: + activity = ActivityBase(**data) + session.add(Activity(**activity.dict())) + except Exception as e: + pass + + return response + + +async def log_activity(request: Request, call_next: Callable, session: Any): + data = { + "ip": request.client.host, + "user_agent": request.headers.get("User-Agent"), + "method": request.method, + "path": request.url.path, + "headers": dict(request.headers), + "query_params": request.query_params, + } + + is_recording = request.url.path.startswith("/async") + + if is_recording: + try: + data["request_body"] = json.dumps(await request.json(), encoding='utf-8') + except Exception as e: + pass + + # get authorization from request + authorization = request.headers.get('Authorization') + if authorization: + auth = AuthJWT(req=request) + token = SecurityToken.from_token(auth) + data["user_id"] = token.user_id + + # call next middleware + response = await call_next(request) + + if is_recording: + # extract response body + data["response_status_code"] = response.status_code + try: + data["response_body"] = json.dumps(await response.json(), encoding='utf-8') + except Exception as e: + pass + + # store activity + with db_context() as session: + activity = ActivityBase(**data) + session.add(Activity(**activity.dict())) + + return response \ No newline at end of file diff --git a/apihub/activity/models.py b/apihub/activity/models.py index 54ca666..f07772a 100644 --- a/apihub/activity/models.py +++ b/apihub/activity/models.py @@ -8,23 +8,17 @@ class Activity(Base): - """ - This class is used to store activity data. - """ - - __tablename__ = "activity" + __tablename__ = "activities" id = Column(Integer, primary_key=True, index=True) created_at = Column(DateTime, default=datetime.now()) - request = Column(String) - username = Column(String) - tier = Column(Enum(SubscriptionTier), default=SubscriptionTier.TRIAL) - status = Column(Enum(ActivityStatus), default=ActivityStatus.ACCEPTED) - request_key = Column(String) - result = Column(String) - payload = Column(String) - ip_address = Column(String) - latency = Column(Float) + ip = Column(String) + path = Column(String) + method = Column(String) + user_id = Column(Integer, default=-1) + request_body = Column(String) + response_status_code = Column(String) + response_body = Column(String) def __str__(self): - return f"{self.request} || {self.username}" + return f"{self.ip} || {self.path} || {self.method} || {self.user_id}" \ No newline at end of file diff --git a/apihub/activity/queries.py b/apihub/activity/queries.py deleted file mode 100644 index ff0637a..0000000 --- a/apihub/activity/queries.py +++ /dev/null @@ -1,90 +0,0 @@ -import datetime - -from sqlalchemy.orm import Query -from sqlalchemy.exc import IntegrityError, NoResultFound - -from ..common.queries import BaseQuery - -from .models import Activity -from .schemas import ActivityCreate, ActivityDetails - - -class ActivityException(Exception): - pass - - -class ActivityQuery(BaseQuery): - def get_query(self) -> Query: - """ - Get query object - :return: Query object. - """ - return self.session.query(Activity) - - def create_activity(self, activity: ActivityCreate) -> None: - """ - Create a new activity. - :param activity_create: ActivityCreate object. - :return: None - """ - activity = Activity(**activity.dict()) - self.session.add(activity) - try: - self.session.commit() - except IntegrityError: - self.session.rollback() - raise ActivityException("IntegrityError") - - def get_activity_by_key(self, request_key: str) -> ActivityDetails: - """ - Get activity by request key. - :param request_key: str - :return: ActivityDetails object. - """ - try: - activity = ( - self.get_query().filter(Activity.request_key == request_key).one() - ) - return ActivityDetails( - created_at=activity.created_at, - request=activity.request, - tier=activity.tier, - status=activity.status, - request_key=activity.request_key, - result=activity.result, - payload=activity.payload, - ip_address=activity.ip_address, - latency=activity.latency, - ) - except NoResultFound: - raise ActivityException - - def update_activity(self, request_key, set_latency=True, **kwargs) -> None: - """ - Update activity by request key. - :param request_key: str - :param set_latency: bool - :param kwargs: dict of fields to update. - :return: None - """ - try: - activity = ( - self.get_query().filter(Activity.request_key == request_key).one() - ) - if set_latency: - kwargs["latency"] = ( - datetime.datetime.now() - activity.created_at - ).total_seconds() - except NoResultFound: - raise ActivityException - - for key, value in kwargs.items(): - setattr(activity, key, value) - - self.session.add(activity) - try: - self.session.commit() - self.session.refresh(activity) - except IntegrityError: - self.session.rollback() - raise ActivityException("IntegrityError") diff --git a/apihub/activity/schemas.py b/apihub/activity/schemas.py index 8199d86..9100c2b 100644 --- a/apihub/activity/schemas.py +++ b/apihub/activity/schemas.py @@ -5,25 +5,19 @@ class ActivityBase(BaseModel): - pass + ip: Optional[str] = None + path: str + method: str + user_id: int = -1 + request_body: Optional[str] = None + response_status_code: Optional[str] = None + response_body: Optional[str] = None -class ActivityCreate(ActivityBase): - request: str - username: Optional[str] = None - tier: str - status: str - request_key: Optional[str] = None - result: Optional[str] = None - payload: Optional[str] = None - ip_address: Optional[str] = None - latency: Optional[float] = None - - -class ActivityDetails(ActivityCreate): +class ActivityDetails(ActivityBase): created_at: datetime class ActivityStatus(str, Enum): ACCEPTED = "ACCEPTED" - PROCESSED = "PROCESSED" + PROCESSED = "PROCESSED" \ No newline at end of file diff --git a/apihub/admin.py b/apihub/admin.py new file mode 100644 index 0000000..542d36a --- /dev/null +++ b/apihub/admin.py @@ -0,0 +1,111 @@ +import sys + +from pydantic import BaseSettings + +from apihub.common.db_session import db_context, Base, DB_ENGINE, create_session +from apihub.common.redis_session import redis_conn +from apihub.security.schemas import UserCreate, UserType, ProfileBase +from apihub.security.queries import UserQuery +from apihub.security.models import * +from apihub.subscription.queries import SubscriptionQuery +from apihub.subscription.models import * +from apihub.subscription.schemas import ( ApplicationCreate, SubscriptionCreate, ) +from apihub.activity.models import * + + +class SuperUser(BaseSettings): + name: str + password: str + email: str + + def as_usercreate(self): + return UserCreate( + name=self.name, + password=self.password, + email=self.email, + role=UserType.ADMIN, + ) + + +def init(): + Base.metadata.bind = DB_ENGINE + Base.metadata.create_all() + print("\n".join(Base.metadata.tables.keys()), file=sys.stderr) + + with db_context() as session: + user = SuperUser().as_usercreate() + UserQuery(session).create_user(user) + print(f"Admin {user.name} is created!", file=sys.stderr) + + +def create_all_statements(): + from sqlalchemy import create_mock_engine + + def metadata_dump(sql, *multiparams, **params): + print(sql.compile(dialect=engine.dialect)) + + engine = create_mock_engine("postgresql://", metadata_dump) + Base.metadata.bind = engine + Base.metadata.create_all(engine, checkfirst=False) + + +def deinit(): + Base.metadata.bind = DB_ENGINE + Base.metadata.drop_all() + print("deinit is done!", file=sys.stderr) + + +def load_data(filename): + import yaml + data = yaml.safe_load(open(filename, 'r', encoding='utf-8'),) + with db_context() as session: + users = {} + applications = {} + pricings = {} + subscriptions = {} + for name, items in data.items(): + if name == 'user': + for item in items: + profile_data = ProfileBase(**item) + name = profile_data.first_name + " " + profile_data.last_name + user_data = UserCreate(name=name, **item) + user = User(**user_data.make_user().dict()) + session.add(user) + + profile = Profile( + user_id=user.id, + **profile_data.dict(), + ) + session.add(profile) + users[user.email] = user + elif name == 'application': + for item in items: + application_data = ApplicationCreate(**item) + pricings = [] + for pricing in item['pricings']: + pricing = Pricing( + tier=pricing['tier'], + credit=pricing['credit'], + price=pricing['price'], + ) + pricings.append(pricing) + application = Application( + name=application_data.name, + url=application_data.url, + description=application_data.description, + user_id=users[item['user']].id, + pricings=pricings, + ) + session.add(application) + applications[application.name] = application + # elif name == 'subscription': + # for item in items: + # subscription = Subscription( + # user_id=users[item.user].id, + # application_id=applications[item.application].id, + # pricing_id=pricings[item.pricing].id, + # expires_at=item.expires_at, + # ) + # subscriptions[subscription.user_id] = subscription + else: + print(f"model {name} not supported", file=sys.stderr) diff --git a/apihub/cli.py b/apihub/cli.py deleted file mode 100644 index 5471d24..0000000 --- a/apihub/cli.py +++ /dev/null @@ -1,226 +0,0 @@ -import sys -import json -import time -import asyncio -import fileinput -from datetime import datetime, timedelta - -try: - import typer -except ImportError: - sys.stderr.write('Please install "apihub[cli] to install cli option"') - sys.exit(1) - -from dotenv import load_dotenv - -load_dotenv() - -from apihub.client import Client -from .subscription.schemas import SubscriptionIn - - -cli = typer.Typer() - - -def try_load_state(username): - client = Client.load_state(filename=f"{username}.apihub") - if client.token is None: - typer.echo("You need to login first") - sys.exit(1) - return client - - -@cli.command() -def login(username: str, password: str, endpoint: str = "http://localhost:5000"): - client = Client({"endpoint": endpoint}) - client.authenticate(username=username, password=password) - client.save_state(filename=f"{username}.apihub") - - -@cli.command() -def refresh_token( - application: str, - admin: str = "", - manager: str = "", - username: str = "", - expires: int = 1, -): - if admin: - admin_client = Client.load_state(filename=f"{admin}.apihub") - elif manager: - admin_client = Client.load_state(filename=f"{manager}.apihub") - else: - admin_client = None - client = Client.load_state(filename=f"{username}.apihub") - - if admin_client: - if admin_client.token is None: - cli.echo("You need to login first") - sys.exit(1) - token = admin_client.refresh_application_token(application, username, expires) - client.applications[application] = token - else: - if client.token is None: - cli.echo("You need to login first") - sys.exit(1) - - client.refresh_application_token(application, username, expires) - client.save_state(filename=f"{username}.apihub") - - -@cli.command() -def list_users(role: str, admin: str = ""): - client = Client.load_state(filename=f"{admin}.apihub") - if client.token is None: - cli.echo("You need to login first") - sys.exit(1) - - users = client.get_users_by_role(role) - for user in users: - print(user) - - -@cli.command() -def create_user( - username: str, password: str, email: str, admin: str = "", role: str = "user" -): - client = Client.load_state(filename=f"{admin}.apihub") - if client.token is None: - cli.echo("You need to login first") - sys.exit(1) - - client.create_user( - { - "username": username, - "password": password, - "role": role, - "email": email, - } - ) - - -@cli.command() -def create_subscription( - username: str, - application: str, - admin: str = "", - days: int = 0, - recurring: bool = False, -): - client = Client.load_state(filename=f"{admin}.apihub") - if client.token is None: - cli.echo("You need to login first") - sys.exit(1) - - expires_at = datetime.now() + timedelta(days=days) if days else None - client.create_subscription( - SubscriptionIn( - username=username, - application=application, - starts_at=datetime.now(), - expires_at=expires_at, - recurring=recurring, - ) - ) - - -@cli.command() -def post_request(application: str, username: str): - client = try_load_state(username) - - data = json.loads(sys.stdin.read()) - response = client.async_request(application, data) - print(response) - return - - MARKER = "# Please write body in json above" - message = cli.edit("\n\n" + MARKER) - if message is not None: - data = json.loads(message.split(MARKER, 1)[0]) - response = client.async_request(application, data) - print(response) - else: - cli.echo("input cannot be empty") - sys.exit(1) - - -@cli.command() -def fetch_result(application: str, key: str, username: str = ""): - client = try_load_state(username) - - response = client.async_result(application, key) - print(response) - - -@cli.command() -def batch_request( - username: str, application: str, filename: str = "", parallel: int = 10 -): - jobs = [] - client = try_load_state(username) - files = [filename] if filename else [] - for i, line in enumerate(fileinput.input(files=files)): - data = json.loads(line) - response = client.async_request(application, data) - if "success" not in response or not response["success"]: - print(f"Failed on input line {i}, request failed", file=sys.stderr) - print(response) - sys.exit(1) - - key = response["key"] - if len(jobs) < parallel: - jobs.append((i, data, key)) - else: - unfinished_jobs = [] - for i, data, key in jobs: - response = client.async_result(application, key) - if "success" in response and response["success"]: - data.update(response["result"]) - print(i, json.dumps(response["result"])) - else: - print(i, response) - unfinished_jobs.append((i, data, key)) - - jobs = unfinished_jobs - - time.sleep(1) - - while len(jobs) > 0: - unfinished_jobs = [] - for i, data, key in jobs: - response = client.async_result(application, key) - if "success" in response and response["success"]: - data.update(response["result"]) - print(i, json.dumps(response["result"])) - else: - print(i, response) - unfinished_jobs.append((i, data, key)) - - jobs = unfinished_jobs - - -@cli.command() -def batch_sync_request( - username: str, application: str, filename: str = "", parallel: int = 10 -): - queue = asyncio.Queue() - - async def request_routine(client, application, data): - response = await client.sync_request(application, data) - if "success" in response and response["success"]: - data.update(response["result"]) - print(i, json.dumps(response["result"])) - else: - print(i, response) - - client = try_load_state(username) - files = [filename] if filename else [] - for i, line in enumerate(fileinput.input(files=files)): - data = json.loads(line) - - queue.put(request_routine(client, application, data)) - time.sleep(1) - - -if __name__ == "__main__": - cli() diff --git a/apihub/client.py b/apihub/client.py deleted file mode 100644 index 55257ba..0000000 --- a/apihub/client.py +++ /dev/null @@ -1,154 +0,0 @@ -import json -from typing import Optional, Any, Dict - -import requests -from pydantic import BaseSettings - -from apihub.security.router import UserCreate -from .subscription.schemas import SubscriptionIn - - -class ClientSettings(BaseSettings): - endpoint: str = "http://localhost" - token: Optional[str] = None - - -class Client: - def __init__(self, settings: Dict[str, Any]) -> None: - if settings is not None: - self.settings = ClientSettings.parse_obj(settings) - else: - self.settings = ClientSettings() - self.token: Optional[str] = None - self.applications: Dict[str, Any] = {} - - def _make_url(self, path: str) -> str: - return f"{self.settings.endpoint}/{path}" - - def save_state(self, filename="~/.apihubrc") -> None: - json.dump( - { - "settings": self.settings.dict(), - "token": self.token, - "applications": self.applications, - }, - open(filename, "w"), - ) - - @staticmethod - def load_state(filename="~/.apihubrc") -> "Client": - state = json.load(open(filename)) - client = Client(state["settings"]) - client.token = state["token"] - client.applications = state["applications"] - return client - - def authenticate(self, username: str, password: str) -> None: - # TODO exceptions - response = requests.get( - self._make_url("_authenticate"), - auth=(username, password), - ) - if response.status_code == 200: - print(response.json()) - self.token = response.json()["access_token"] - else: - print(response.json()) - - def create_user(self, user): - username = user["username"] - response = requests.post( - self._make_url(f"user/{username}"), - headers={"Authorization": f"Bearer {self.token}"}, - json=UserCreate.parse_obj(user).dict(), - ) - if response.status_code == 200: - return True - - def get_users_by_role(self, role): - response = requests.get( - self._make_url(f"users/{role}"), - headers={"Authorization": f"Bearer {self.token}"}, - ) - if response.status_code == 200: - return response.json() - - def create_subscription(self, subscription: SubscriptionIn): - response = requests.post( - self._make_url("subscription"), - headers={"Authorization": f"Bearer {self.token}"}, - data=subscription.json(), - ) - if response.status_code == 200: - return True - else: - raise Exception(response.json()) - - def refresh_application_token( - self, application: str, username: str, expires_days: int - ) -> None: - # TODO exceptions - params = {} - if username: - params["username"] = username - if expires_days: - params["expires_days"] = expires_days - - response = requests.get( - self._make_url(f"token/{application}"), - headers={"Authorization": f"Bearer {self.token}"}, - params=params, - ) - if response.status_code == 200: - print(response.text) - print(response.json()) - self.applications[application] = response.json()["token"] - return response.json()["token"] - else: - print(response.text) - print(response.json()) - - def _check_token_for(self, application: str) -> bool: - return application in self.applications - - def async_request(self, application: str, data: dict): - response = requests.post( - self._make_url(f"async/{application}"), - headers={ - "Authorization": f"Bearer {self.applications[application]}", - }, - json=data, - ) - if response.status_code == 200: - return response.json() - else: - return response.json() - - def async_result(self, application: str, key: str): - # TODO wait and timeout - response = requests.get( - self._make_url(f"async/{application}"), - params={ - "key": key, - }, - headers={ - "Authorization": f"Bearer {self.token}", - }, - ) - if response.status_code == 200: - return response.json() - else: - return response.json() - - def sync_request(self, application: str, data: dict): - response = requests.post( - self._make_url(f"sync/{application}"), - headers={ - "Authorization": f"Bearer {self.applications[application]}", - }, - json=data, - ) - if response.status_code == 200: - return response.json() - else: - return response.json() diff --git a/apihub/result.py b/apihub/result.py index 2b170b2..11b9eb9 100644 --- a/apihub/result.py +++ b/apihub/result.py @@ -4,8 +4,6 @@ from pipeline import ProcessorSettings, Processor, Command, CommandActions, Definition -from .activity.queries import ActivityQuery -from .activity.schemas import ActivityStatus from .common.db_session import create_session from .utils import Result, RedisSettings, DefinitionManager from . import __worker__, __version__ @@ -61,10 +59,10 @@ def process_command(self, command: Command) -> None: def process(self, message_content, message_id): self.logger.info("Processing MESSAGE") result = Result.parse_obj(message_content) - if result.status == ActivityStatus.PROCESSED: - result.result = { - k: message_content.get(k) for k in self.message.logs[-1].updated - } + # if result.status == ActivityStatus.PROCESSED: + # result.result = { + # k: message_content.get(k) for k in self.message.logs[-1].updated + # } self.api_counter.labels(api=result.api, user=result.user, status=result.status) @@ -74,10 +72,10 @@ def process(self, message_content, message_id): r = result.json() self.redis.set(message_id, r, ex=86400) - if result.status == ActivityStatus.PROCESSED: - ActivityQuery(self.session).update_activity( - message_id, **{"status": ActivityStatus.PROCESSED} - ) + # if result.status == ActivityStatus.PROCESSED: + # ActivityQuery(self.session).update_activity( + # message_id, **{"status": ActivityStatus.PROCESSED} + # ) return None diff --git a/apihub/security/depends.py b/apihub/security/depends.py index 01b62de..13156ec 100644 --- a/apihub/security/depends.py +++ b/apihub/security/depends.py @@ -4,7 +4,7 @@ from fastapi import HTTPException, Depends, Request from fastapi_jwt_auth import AuthJWT -from .schemas import UserBase +from .schemas import UserBaseWithId, SecurityToken HTTP_429_TOO_MANY_REQUESTS = 429 @@ -52,10 +52,15 @@ def __init__(self, role: Optional[str] = None, roles: List[str] = list()): def __call__(self, Authorize: AuthJWT = Depends()): Authorize.jwt_required() - roles = Authorize.get_raw_jwt().get("roles", {}) - if any(role in roles for role in self.roles): - username = Authorize.get_jwt_subject() - return username + token = SecurityToken.from_token(Authorize) + + if token.role in self.roles: + return UserBaseWithId( + id=token.user_id, + name=token.name, + email=token.email, + role=token.role, + ) raise HTTPException( HTTP_403_FORBIDDEN, @@ -63,18 +68,22 @@ def __call__(self, Authorize: AuthJWT = Depends()): ) -def require_token(Authorize: AuthJWT = Depends()) -> UserBase: +def require_token(Authorize: AuthJWT = Depends()) -> UserBaseWithId: Authorize.jwt_required() - roles = Authorize.get_raw_jwt()["roles"] - username = Authorize.get_jwt_subject() - return UserBase( - username=username, - role=roles[0], + + token = SecurityToken.from_token(Authorize) + + return UserBaseWithId( + id=token.user_id, + name=token.name, + email=token.email, + role=token.role, ) require_admin = UserOfRole(role="admin") -require_manager = UserOfRole(role="manager") +require_publisher = UserOfRole(role="publisher") require_user = UserOfRole(role="user") require_app = UserOfRole(role="app") -require_manager_or_admin = UserOfRole(roles=["admin", "manager"]) +require_publisher_or_admin = UserOfRole(roles=["admin", "publisher"]) +require_logged_in = UserOfRole(roles=["admin", "publisher", "user", "app"]) diff --git a/apihub/security/helpers.py b/apihub/security/helpers.py index 1907a00..90b7682 100644 --- a/apihub/security/helpers.py +++ b/apihub/security/helpers.py @@ -1,6 +1,8 @@ import os +import datetime import hashlib from base64 import b64encode, b64decode +from fastapi_jwt_auth import AuthJWT def hash_password(password, salt=None): @@ -20,5 +22,6 @@ def hash_password(password, salt=None): password.encode("utf-8"), salt_, 100000, + dklen=64, ).hex() - return salt, hashed_password + return salt, hashed_password \ No newline at end of file diff --git a/apihub/security/models.py b/apihub/security/models.py index 2a09716..47ce2a9 100644 --- a/apihub/security/models.py +++ b/apihub/security/models.py @@ -8,6 +8,7 @@ String, DateTime, Enum, + ForeignKey, ) from .schemas import UserType @@ -22,14 +23,35 @@ class User(Base): __tablename__ = "users" id = Column(Integer, primary_key=True, index=True) - username = Column(String, unique=True, index=True, nullable=False) - email = Column(String, index=True) + email = Column(String, index=True, unique=True, nullable=False) + name = Column(String, nullable=False) salt = Column(String) hashed_password = Column(String) role = Column(Enum(UserType), default=UserType.USER) is_active = Column(Boolean, default=True) created_at = Column(DateTime, default=datetime.now()) - subscriptions = relationship("Subscription", back_populates="user") + + profile = relationship("Profile", uselist=False, back_populates="user") + + def __str__(self): + return f"{self.email} || {self.role} || {self.is_active}" + + +class Profile(Base): + """ + This class is used to store user profile data. + """ + __tablename__ = "profiles" + + id = Column(Integer, primary_key=True, index=True) + first_name = Column(String) + last_name = Column(String) + bio = Column(String) + url = Column(String) + avatar = Column(String) + + user_id = Column(Integer, ForeignKey("users.id")) + user = relationship("User", cascade = "all,delete", back_populates="profile") def __str__(self): - return f"{self.username} || {self.role}" + return f"{self.user_id} || {self.name}" diff --git a/apihub/security/queries.py b/apihub/security/queries.py index 1b99dc8..c3156f4 100644 --- a/apihub/security/queries.py +++ b/apihub/security/queries.py @@ -1,4 +1,4 @@ -from typing import List +from typing import List, Optional from sqlalchemy.exc import IntegrityError, DataError from sqlalchemy.orm.exc import NoResultFound @@ -28,26 +28,30 @@ def get_user_by_id(self, user_id: int) -> UserSession: :param user_id: integer id :return: UserSession object. """ - user = self.get_query().filter(User.id == user_id).one() - return UserSession( - id=user.id, - username=user.username, - role=user.role, - salt=user.salt, - hashed_password=user.hashed_password, - ) + try: + user = self.get_query().filter(User.id == user_id).one() + return UserSession( + id=user.id, + name=user.name, + email=user.email, + role=user.role, + salt=user.salt, + hashed_password=user.hashed_password, + ) + except NoResultFound: + raise UserException("User not found.") - def get_user_by_username_and_password( - self, username: str, password: str + def get_user_by_email_and_password( + self, email: str, password: str ) -> UserSession: """ - Get user by username and password. - :param username: str + Get user by email and password. + :param email: str :param password: str :return: UserSession object. """ try: - user = self.get_query().filter(User.username == username).one() + user = self.get_query().filter(User.email == email).one() except NoResultFound: raise UserException @@ -55,47 +59,50 @@ def get_user_by_username_and_password( if user.hashed_password == hashed_password: return UserSession( id=user.id, - username=user.username, + name=user.name, + email=user.email, role=user.role, salt=user.salt, hashed_password=user.hashed_password, ) raise UserException - def get_user_by_username(self, username: str) -> UserSession: + def get_user_by_email(self, email: str) -> UserSession: """ - Get user by username. - :param username: str + Get user by email. + :param email: str :return: UserSession object. """ try: - user = self.get_query().filter(User.username == username).one() + user = self.get_query().filter(User.email == email).one() except NoResultFound: raise UserException return UserSession( id=user.id, - username=user.username, + name=user.name, + email=user.email, role=user.role, salt=user.salt, hashed_password=user.hashed_password, ) - def get_users_by_usernames(self, usernames: List[str]) -> List[UserSession]: + def get_users_by_emails(self, emails: List[str]) -> List[UserSession]: """ - Get users by usernames. - :param usernames: str + Get users by emails. + :param emails: str :return: list of UserSession object. """ try: - users = self.get_query().filter(User.username.in_(usernames)) + users = self.get_query().filter(User.email.in_(emails)) except NoResultFound: raise UserException return [ UserSession( id=user.id, - username=user.username, + name=user.name, + email=user.email, role=user.role, salt=user.salt, hashed_password=user.hashed_password, @@ -116,13 +123,14 @@ def get_users_by_role(self, role) -> List[UserBase]: return [ UserBase( - username=user.username, + name=user.name, + email=user.email, role=user.role, ) for user in users ] - def create_user(self, user: UserCreate) -> bool: + def create_user(self, user: UserCreate) -> Optional[int]: """ Create a new user. :param user: UserCreate object. @@ -136,19 +144,19 @@ def create_user(self, user: UserCreate) -> bool: self.session.rollback() return False - return True + return db_user.id - def change_password(self, username: str, password: str) -> bool: + def change_password(self, email: str, password: str) -> bool: """ Change password for a user. - :param username: str + :param email: str :param password: str :return: boolean. """ try: - user_in_db = self.get_query().filter(User.username == username).one() + user_in_db = self.get_query().filter(User.email == email).one() except NoResultFound: - raise UserException + raise UserException(f"User {email} not found") _, hashed_password = hash_password(password, salt=user_in_db.salt) user_in_db.hashed_password = hashed_password diff --git a/apihub/security/router.py b/apihub/security/router.py index e52038a..ac5eb5a 100644 --- a/apihub/security/router.py +++ b/apihub/security/router.py @@ -6,7 +6,7 @@ from fastapi_jwt_auth import AuthJWT from ..common.db_session import create_session -from .schemas import UserCreate, UserBase, UserRegister, UserType +from .schemas import UserCreate, UserBase, UserRegister, UserType, SecurityToken from .queries import UserQuery, UserException from .depends import require_token, require_admin, require_app @@ -27,14 +27,7 @@ def get_config(): return SecuritySettings() -class AuthenticateResponse(BaseModel): - username: str - roles: List[str] - access_token: str - expires_time: int - - -@router.get("/_authenticate") +@router.get("/_authenticate", response_model=SecurityToken) async def _authenticate( credentials: HTTPBasicCredentials = Depends(security), expires_days: int = 1, @@ -42,68 +35,54 @@ async def _authenticate( ): query = UserQuery(session) try: - user = query.get_user_by_username_and_password( - username=credentials.username, + user = query.get_user_by_email_and_password( + email=credentials.username, password=credentials.password, ) except UserException: raise HTTPException(HTTP_403_FORBIDDEN, "User not found or wrong password") - roles = [user.role] - # make sure the max expires_days won't exceed setting if expires_days > SecuritySettings().security_token_expires_time: expires_days = SecuritySettings().security_token_expires_time - Authorize = AuthJWT() - expires_time = datetime.timedelta(days=expires_days) - access_token = Authorize.create_access_token( - subject=user.username, - user_claims={"roles": roles}, - expires_time=expires_time, - ) - return AuthenticateResponse( - username=user.username, - roles=roles, - expires_time=expires_time.seconds, - access_token=access_token, + security_token = SecurityToken( + email=user.email, + role=user.role, + name=user.name, + user_id=user.id, + expires_days=expires_days, ) + return security_token + -@router.get("/user/{username}") +@router.get("/user") async def get_user( - username: str, - current_user: UserBase = Depends(require_token), + user: UserBase = Depends(require_token), session=Depends(create_session), ): - if current_user.is_admin or current_user.is_manager: - if username == "me": - username = current_user.username - elif username == "me": - username = current_user.username - else: - raise HTTPException(HTTP_403_FORBIDDEN, "You have no permission") - query = UserQuery(session) - user = query.get_user_by_username(username=username) + user = query.get_user_by_email(email=user.email) return UserBase( - username=user.username, + name=user.name, + email=user.email, role=user.role, ) class GetUserAdminIn(BaseModel): - usernames: str + emails: str @router.get("/user") async def get_user_admin( - usernames: GetUserAdminIn, - current_username: str = Depends(require_admin), + group: GetUserAdminIn, + admin: str = Depends(require_admin), session=Depends(create_session), ): query = UserQuery(session) - users = query.get_users_by_usernames(usernames=usernames.usernames.split(",")) + users = query.get_users_by_emails(emails=group.emails.split(",")) return [UserBase(**user.dict()) for user in users] @@ -114,21 +93,21 @@ class ChangePasswordIn(BaseModel): @router.post("/user/_password") async def change_password( password: ChangePasswordIn, - current_user: UserBase = Depends(require_token), + user: UserBase = Depends(require_token), session=Depends(create_session), ): query = UserQuery(session) - query.change_password(username=current_user.username, password=password.password) + query.change_password(email=user.email, password=password.password) @router.post("/user") async def create_user( - new_user: UserCreate, - current_username: str = Depends(require_admin), + user: UserCreate, + admin: str = Depends(require_admin), session=Depends(create_session), ): query = UserQuery(session) - query.create_user(new_user) + query.create_user(user) # TODO handling results return {} @@ -136,7 +115,7 @@ async def create_user( @router.get("/users/{role}") async def list_users( role: str, - current_username: str = Depends(require_admin), + admin: str = Depends(require_admin), session=Depends(create_session), ): query = UserQuery(session) @@ -144,29 +123,29 @@ async def list_users( return users -@router.post("/user/{username}/_password") +@router.post("/user/{email}/_password") async def change_password_admin( - username: str, + email: str, password: ChangePasswordIn, - current_username: str = Depends(require_admin), + admin: str = Depends(require_admin), session=Depends(create_session), ): query = UserQuery(session) - query.change_password(username=username, password=password.password) + query.change_password(email=email, password=password.password) @router.post("/register") async def register_user( - new_user: UserRegister, - current_username: str = Depends(require_app), + user: UserRegister, + app: str = Depends(require_app), # FIXME session=Depends(create_session), ): query = UserQuery(session) query.create_user( UserCreate( - username=new_user.username, - email=new_user.email, - password=new_user.password, + name=user.name, + email=user.email, + password=user.password, role=UserType.USER, ) ) diff --git a/apihub/security/schemas.py b/apihub/security/schemas.py index d6ad0a9..eda1d93 100644 --- a/apihub/security/schemas.py +++ b/apihub/security/schemas.py @@ -1,19 +1,61 @@ +import datetime from enum import Enum +from typing import Optional from pydantic import BaseModel +from fastapi_jwt_auth import AuthJWT from .helpers import hash_password +class SecurityToken(BaseModel): + email: str + role: str + name: str + user_id: int + expires_days: Optional[int] = None + access_token: Optional[str] = None + + def __init__(self, **kwargs): + super().__init__(**kwargs) + if self.access_token is None: + self.access_token = self.to_token() + + def to_token(self): + Authorize = AuthJWT() + expires_time = datetime.timedelta(days=self.expires_days) + access_token = Authorize.create_access_token( + subject=self.email, + user_claims={ + "role": self.role, "name": self.name, "user_id": self.user_id, + }, + expires_time=expires_time, + ) + return access_token + + @classmethod + def from_token(cls, Authorize: AuthJWT): + email = Authorize.get_jwt_subject() + claims = Authorize.get_raw_jwt() + return cls( + email=email, + role=claims["role"], + name=claims["name"], + user_id=claims["user_id"], + expires_days=0, + ) + + class UserType(str, Enum): USER = "user" - MANAGER = "manager" + PUBLISHER = "publisher" APP = "app" ADMIN = "admin" class UserBase(BaseModel): - username: str + email: str + name: str role: UserType @property @@ -21,8 +63,8 @@ def is_admin(self) -> bool: return self.role == UserType.ADMIN @property - def is_manager(self) -> bool: - return self.role == UserType.MANAGER + def is_publisher(self) -> bool: + return self.role == UserType.PUBLISHER @property def is_user(self) -> bool: @@ -33,14 +75,17 @@ def is_app(self) -> bool: return self.role == UserType.APP +class UserBaseWithId(UserBase): + id: int + + class UserRegister(BaseModel): - username: str + name: str email: str password: str class UserCreate(UserBase): - email: str password: str def make_user(self): @@ -50,7 +95,7 @@ def make_user(self): """ salt, hashed_password = hash_password(self.password) return UserCreateHashed( - username=self.username, + name=self.name, email=self.email, salt=salt, hashed_password=hashed_password, @@ -71,3 +116,11 @@ class UserSession(UserBase): class User(UserSession): pass + + +class ProfileBase(BaseModel): + first_name: str + last_name: str + bio: Optional[str] + url: Optional[str] + avatar: Optional[str] \ No newline at end of file diff --git a/apihub/server.py b/apihub/server.py index 6249a53..716b4ce 100644 --- a/apihub/server.py +++ b/apihub/server.py @@ -1,9 +1,11 @@ import sys import functools +from functools import partial import logging -from typing import Dict, Any +from typing import Coroutine, Dict, Any -from fastapi import FastAPI, HTTPException, Request, Query, Depends, BackgroundTasks +from fastapi import FastAPI, HTTPException, Request, Query, Depends +from fastapi.responses import JSONResponse from pydantic import BaseModel, Field from fastapi_jwt_auth import AuthJWT from fastapi_jwt_auth.exceptions import AuthJWTException @@ -13,14 +15,13 @@ from dotenv import load_dotenv from pipeline import Message, Settings, Command, CommandActions, Monitor -from .common.db_session import db_context -from .activity.schemas import ActivityStatus, ActivityCreate -from .activity.queries import ActivityQuery +from .common.db_session import create_session +from .activity.schemas import ActivityStatus +from .activity.middlewares import ActivityLogger from .security.depends import RateLimiter, RateLimits, require_user from .security.router import router as security_router -from .subscription.depends import require_subscription +from .subscription.depends import require_subscription, SubscriptionToken from .subscription.router import router as subscription_router -from .subscription.schemas import SubscriptionBase from .utils import ( State, make_topic, @@ -86,6 +87,8 @@ def jwt_get_config(): subscription_router, tags=["subscription"], dependencies=[Depends(ip_rate_limited)] ) +api.add_middleware(ActivityLogger) + @api.exception_handler(AuthJWTException) def authjwt_exception_handler(request: Request, exc: AuthJWTException): @@ -126,7 +129,7 @@ async def define_service( return {"define": f"application {application}"} -async def make_request(username: str, application: str, request: Request): +async def make_request(email: str, application: str, request: Request): """Make request to application""" key = make_key() @@ -149,7 +152,7 @@ async def make_request(username: str, application: str, request: Request): # inject user information info = Result( - user=username, + user=email, api=application, status=ActivityStatus.ACCEPTED, ) @@ -164,12 +167,12 @@ async def make_request(username: str, application: str, request: Request): return key -def fetch_result(username: str, application: str, key: str): +def fetch_result(email: str, application: str, key: str): """fetch result""" result = get_redis().get(key) if result is None: operation_counter.labels( - api=application, user=username, operation="result_not_found" + api=application, user=email, operation="result_not_found" ).inc() raise HTTPException( status_code=404, @@ -185,7 +188,7 @@ def fetch_result(username: str, application: str, key: str): ) elif result.status != ActivityStatus.PROCESSED: operation_counter.labels( - api=application, user=username, operation="error" + api=application, user=email, operation="error" ).inc() # FIXME change status code raise HTTPException( @@ -204,37 +207,16 @@ def fetch_result(username: str, application: str, key: str): ) async def async_service( request: Request, - # background_tasks: BackgroundTasks, - subscription: SubscriptionBase = Depends(require_subscription), + subscription: SubscriptionToken = Depends(require_subscription), ): """generic handler for async api.""" - username = subscription.username - tier = subscription.tier - application = subscription.application - operation_counter.labels(api=application, user=username, operation="received").inc() - - key = await make_request(username, application, request) - - operation_counter.labels(api=application, user=username, operation="accepted").inc() - - # activity = ActivityCreate( - # request=f"/async/{application}", - # username=username, - # tier=tier, - # status=ActivityStatus.ACCEPTED, - # request_key=str(key), - # result=str(info.dict()), - # payload=str(dct), - # ip_address=str(request.client.host), - # latency=0.0, - # ) - # - # def add_activity_task(activity): - # with db_context() as session: - # ActivityQuery(session).create_activity(activity) - # - # background_tasks.add_task(add_activity_task, activity=activity) + operation_counter.labels(api=subscription.application, user=subscription.email, operation="received").inc() + + key = await make_request(subscription.email, subscription.application, request) + + operation_counter.labels(api=subscription.application, user=subscription.email, operation="accepted").inc() + return AsyncAPIRequestResponse(success=True, key=key) diff --git a/apihub/subscription/depends.py b/apihub/subscription/depends.py index d3966d9..630bf10 100644 --- a/apihub/subscription/depends.py +++ b/apihub/subscription/depends.py @@ -1,3 +1,4 @@ +from pydantic import BaseModel from fastapi import HTTPException, Depends from fastapi_jwt_auth import AuthJWT from redis import Redis @@ -5,7 +6,7 @@ from ..common.db_session import create_session from ..common.redis_session import redis_conn -from .schemas import SubscriptionBase +from .schemas import SubscriptionToken from .queries import SubscriptionQuery from .helpers import make_key, BALANCE_KEYS @@ -16,7 +17,7 @@ def require_subscription( application: str, Authorize: AuthJWT = Depends() -) -> SubscriptionBase: +) -> SubscriptionToken: """ This function is used to check if the user has a valid subscription token. :param application: str @@ -24,43 +25,36 @@ def require_subscription( :return: SubscriptionBase object. """ Authorize.jwt_required() - username = Authorize.get_jwt_subject() + subscription_token = SubscriptionToken.from_token(Authorize) - claims = Authorize.get_raw_jwt() - subscription_claim = claims.get("subscription") - tier_claim = claims.get("tier") - if subscription_claim != application: + if subscription_token.application != application: raise HTTPException( HTTP_403_FORBIDDEN, "The API key doesn't have permission to perform the request", ) - return SubscriptionBase( - username=username, tier=tier_claim, application=subscription_claim - ) + + return subscription_token def require_subscription_balance( - subscription: SubscriptionBase = Depends(require_subscription), + subscription: SubscriptionToken = Depends(require_subscription), redis: Redis = Depends(redis_conn), session=Depends(create_session), -) -> str: +) -> SubscriptionToken: """ This function is used to check if the user has enough balance to perform. :param subscription: str :param redis: Redis object. :param session: Session object. - :return: username str. + :return: email str. """ - username = subscription.username - tier = subscription.tier - application = subscription.application - - key = make_key(username, application, tier) + key = make_key(subscription) balance = redis.decr(key) + - if balance == -1: - subscription = SubscriptionQuery(session).get_active_subscription( - username, application + if balance is None or balance == -1: + subscription = SubscriptionQuery(session).get_subscription( + subscription.subscription_id ) balance = subscription.credit - subscription.balance - 1 if balance > 0: @@ -69,7 +63,7 @@ def require_subscription_balance( if balance <= 0: SubscriptionQuery(session).update_balance_in_subscription( - username, application, tier, redis + subscription, redis ) if balance < 0: @@ -78,4 +72,4 @@ def require_subscription_balance( "You have used up all credit for this API", ) - return username + return subscription \ No newline at end of file diff --git a/apihub/subscription/helpers.py b/apihub/subscription/helpers.py index f39ac29..17960c7 100644 --- a/apihub/subscription/helpers.py +++ b/apihub/subscription/helpers.py @@ -6,30 +6,23 @@ BALANCE_KEYS = "balance:keys" -def make_key(username: str, application: str, tier: str) -> str: - """ - Make key for redis. - :param username: str - :param application: str - :param tier: str - :return: str. - """ - return f"balance:{username}:{application}:{tier}" +def make_key(subscription) -> str: + return f"balance:{subscription.user_id}:{subscription.application_id}:{subscription.tier}" @contextmanager def get_and_reset_balance_in_cache( - username: str, application: str, tier: str, redis: Redis + subscription, redis: Redis ) -> None: """ Get balance from cache and delete it. - :param username: str + :param email: str :param application: str :param tier: str :param redis: Redis object. :return: None """ - key = make_key(username, application, tier) + key = make_key(subscription) balance = redis.get(key) yield int(balance) diff --git a/apihub/subscription/models.py b/apihub/subscription/models.py index 23b3619..1c50b66 100644 --- a/apihub/subscription/models.py +++ b/apihub/subscription/models.py @@ -13,7 +13,11 @@ from sqlalchemy.orm import relationship from ..common.db_session import Base -from .schemas import SubscriptionTier, ApplicationCreate, SubscriptionPricingCreate +from .schemas import SubscriptionTier, ApplicationCreate, PricingCreate + + +def set_default_path(context): + return context.get_current_parameters()["name"].lower().replace(" ", "-") class Application(Base): @@ -21,15 +25,21 @@ class Application(Base): This class is used to store application data. """ - __tablename__ = "application" + __tablename__ = "applications" id = Column(Integer, primary_key=True, index=True) name = Column(String, unique=True, index=True, nullable=False) + path = Column(String, unique=True, index=True, nullable=False, default=set_default_path) url = Column(String) description = Column(String) + is_active = Column(Boolean, default=True) - subscriptions = relationship("Subscription", backref="app") - subscriptions_pricing = relationship("SubscriptionPricing", backref="app") + created_at = Column(DateTime, default=datetime.now()) + user_id = Column(Integer, ForeignKey("users.id")) + + user = relationship("User") + subscriptions = relationship("Subscription", back_populates="application") + pricings = relationship("Pricing", back_populates="application") def __str__(self): return f"{self.name} || {self.url}" @@ -39,32 +49,41 @@ def to_schema(self, with_pricing=False) -> ApplicationCreate: name=self.name, url=self.url, description=self.description, - pricing=[ - SubscriptionPricingCreate( + pricings=[ + PricingCreate( tier=pricing.tier, price=pricing.price, credit=pricing.credit, application=self.name, ) - for pricing in self.subscriptions_pricing + for pricing in self.pricings ] if with_pricing else [], ) -class SubscriptionPricing(Base): +class Pricing(Base): """ This class is used to store subscription pricing data. """ - __tablename__ = "subscription_pricing" - __table_args__ = (UniqueConstraint("application", "tier", name="application_tier"),) + __tablename__ = "pricings" + __table_args__ = ( + UniqueConstraint( + "application_id", "tier", name="application_tier_constraint" + ), + ) id = Column(Integer, primary_key=True, index=True) tier = Column(Enum(SubscriptionTier), default=SubscriptionTier.TRIAL) price = Column(Integer) credit = Column(Integer) + is_active = Column(Boolean, default=True) + + created_at = Column(DateTime, default=datetime.now()) + + application_id = Column(Integer, ForeignKey("applications.id"), nullable=False) + application = relationship("Application", uselist=False, back_populates="pricings") - application = Column(String, ForeignKey("application.name"), nullable=False) def __str__(self): - return f"{self.application} || {self.tier} || {self.price}" + return f"{self.application_id} || {self.tier} || {self.price}" class Subscription(Base): @@ -73,28 +92,33 @@ class Subscription(Base): """ __tablename__ = "subscriptions" - __table_args__ = ( - UniqueConstraint( - "application", "tier", "username", name="application_tier_username" - ), - ) + # __table_args__ = ( + # UniqueConstraint( + # "application_id", "tier", "user_id", name="application_tier_user_constraint" + # ), + # ) id = Column(Integer, primary_key=True, index=True) tier = Column(Enum(SubscriptionTier), default=SubscriptionTier.TRIAL) - active = Column(Boolean, default=True) + is_active = Column(Boolean, default=True) credit = Column(Integer, default=0) balance = Column(Integer, default=0) starts_at = Column(DateTime, default=datetime.now()) expires_at = Column(DateTime) recurring = Column(Boolean, default=False) created_at = Column(DateTime, default=datetime.now()) - created_by = Column(String) + # created_by = Column(Integer, ForeignKey("users.id")) notes = Column(String) - username = Column(String, ForeignKey("users.username"), nullable=False) - user = relationship("User", back_populates="subscriptions") + user_id = Column(Integer, ForeignKey("users.id")) + owner = relationship("User") + + application_id = Column(Integer, ForeignKey("applications.id"), nullable=False) + application = relationship("Application", uselist=False, back_populates="subscriptions") + + pricing_id = Column(Integer, ForeignKey("pricings.id"), nullable=False) + pricing = relationship("Pricing", uselist=False) - application = Column(String, ForeignKey("application.name"), nullable=False) def __str__(self): - return f"{self.application} || {self.tier} || {self.username}" + return f"{self.application} || {self.tier} || {self.email}" diff --git a/apihub/subscription/queries.py b/apihub/subscription/queries.py index 2f3656d..fbd87fb 100644 --- a/apihub/subscription/queries.py +++ b/apihub/subscription/queries.py @@ -8,12 +8,13 @@ from sqlalchemy.orm import Query from ..common.queries import BaseQuery -from .models import Subscription, Application, SubscriptionPricing +from .models import Subscription, Application, Pricing from .schemas import ( SubscriptionCreate, SubscriptionDetails, ApplicationCreate, - SubscriptionPricingBase, + ApplicationCreateWithOwner, + PricingBase, ) from .helpers import get_and_reset_balance_in_cache @@ -22,7 +23,7 @@ class ApplicationException(Exception): pass -class SubscriptionPricingException(Exception): +class PricingException(Exception): pass @@ -38,27 +39,26 @@ def get_query(self) -> Query: """ return self.session.query(Application) - def create_application(self, application: ApplicationCreate): + def create_application(self, application: ApplicationCreateWithOwner) -> ApplicationCreateWithOwner: """ Create an application. :param application: Application details. :return: Application object. """ - pricing_list = [] - for pricing in application.pricing: - pricing_list.append( - SubscriptionPricing( + pricings = [] + for pricing in application.pricings: + pricings.append( + Pricing( tier=pricing.tier, price=pricing.price, credit=pricing.credit, - application=application.name, ) ) application_object = Application( name=application.name, url=application.url, description=application.description, - subscriptions_pricing=pricing_list, + pricings=pricings, ) try: self.session.add(application_object) @@ -68,19 +68,31 @@ def create_application(self, application: ApplicationCreate): self.session.rollback() raise ApplicationException(f"Error creating application: {e}") - def get_application(self, name: str) -> ApplicationCreate: + def get_application(self, application_id: int) -> ApplicationCreate: """ Get application by name. :param name: Application name. :return: application object. """ try: - application = self.get_query().filter(Application.name == name).one() + application = self.get_query().filter(Application.id == application_id).one() return application.to_schema(with_pricing=True) except NoResultFound: - raise ApplicationException(f"Application {name} not found.") + raise ApplicationException(f"Application {application_id} not found.") - def get_applications(self, username=None) -> List[ApplicationCreate]: + def get_application_by_path(self, path: str) -> ApplicationCreate: + """ + Get application by name. + :param name: Application name. + :return: application object. + """ + try: + application = self.get_query().filter(Application.path == path).one() + return application.to_schema(with_pricing=True) + except NoResultFound: + raise ApplicationException(f"Application with path {path} not found.") + + def get_applications(self, email=None) -> List[ApplicationCreate]: """ List applications. :return: List of applications. @@ -89,26 +101,26 @@ def get_applications(self, username=None) -> List[ApplicationCreate]: return list(applications) -class SubscriptionPricingQuery(BaseQuery): +class PricingQuery(BaseQuery): def get_query(self) -> Query: """ Get query object. :return: Query object. """ - return self.session.query(SubscriptionPricing) + return self.session.query(Pricing) def create_subscription_pricing( self, tier: str, application: str, price: int, credit: int - ) -> SubscriptionPricing: + ) -> Pricing: """ Create subscription pricing. :param tier: Subscription tier. :param application: Application name. :param price: Price. :param credit: Credit. - :return: SubscriptionPricing object. + :return: Pricing object. """ - subscription_pricing = SubscriptionPricing( + subscription_pricing = Pricing( tier=tier, application=application, price=price, @@ -120,31 +132,31 @@ def create_subscription_pricing( return subscription_pricing except Exception as e: self.session.rollback() - raise SubscriptionPricingException( + raise PricingException( f"Error creating subscription pricing: {e}" ) def get_subscription_pricing( - self, application: str, tier: str - ) -> SubscriptionPricing: + self, application_id: int, tier: str + ) -> Pricing: """ Get subscription pricing by application name and tier. :param application: Application name. :param tier: Subscription tier. - :return: SubscriptionPricing object. + :return: Pricing object. """ try: return ( self.get_query() .filter( - SubscriptionPricing.application == application, - SubscriptionPricing.tier == tier, + Pricing.application_id == application_id, + Pricing.tier == tier, ) .one() ) except NoResultFound: - raise SubscriptionPricingException( - f"Subscription pricing for application {application} and tier {tier} not found." + raise PricingException( + f"Subscription pricing for application {application_id} and tier {tier} not found." ) @@ -166,7 +178,7 @@ def create_subscription(self, subscription_create: SubscriptionCreate): found_existing_subscription = True try: self.get_active_subscription( - subscription_create.username, subscription_create.application + subscription_create.user_id, subscription_create.application_id ) except SubscriptionException: found_existing_subscription = False @@ -176,22 +188,21 @@ def create_subscription(self, subscription_create: SubscriptionCreate): "Found existing subscription, please delete it before create new subscription" ) - subscription_pricing = SubscriptionPricingQuery( + subscription_pricing = PricingQuery( self.session ).get_subscription_pricing( - subscription_create.application, subscription_create.tier + subscription_create.application_id, subscription_create.tier ) new_subscription = Subscription( - username=subscription_create.username, - application=subscription_create.application, - active=subscription_create.active, + user_id=subscription_create.user_id, + application_id=subscription_create.application_id, + pricing_id=subscription_create.pricing_id, tier=subscription_create.tier, credit=subscription_pricing.credit, balance=subscription_pricing.credit, expires_at=subscription_create.expires_at, recurring=subscription_create.recurring, - created_by=subscription_create.created_by, ) try: self.session.add(new_subscription) @@ -200,12 +211,76 @@ def create_subscription(self, subscription_create: SubscriptionCreate): self.session.rollback() raise SubscriptionException(f"Error creating subscription: {e}") + def get_active_subscription_by_path( + self, user_id: int, path: str + ) -> SubscriptionDetails: + """ + Get active subscription of a user. + :param email: str + :param application: str + :return: SubscriptionDetails object. + """ + try: + application = self.session.query(Application).filter( + Application.path == path + ).one() + + subscription = self.get_query().filter( + Subscription.user_id == user_id, + Subscription.application_id == application.id, + Subscription.is_active == true(), + or_( + Subscription.expires_at.is_(None), + Subscription.expires_at > datetime.now(), + ), + ).one() + except NoResultFound: + raise SubscriptionException("Subscription not found.") + + return SubscriptionDetails( + id=subscription.id, + user_id=subscription.user_id, + application_id=subscription.application_id, + pricing_id=subscription.pricing_id, + tier=subscription.tier, + is_active=subscription.is_active, + credit=subscription.credit, + balance=subscription.balance, + starts_at=subscription.starts_at, + expires_at=subscription.expires_at, + recurring=subscription.recurring, + created_at=subscription.created_at, + ) + + def get_subscription(self, subscription_id: int) -> SubscriptionDetails: + try: + subscription = self.get_query().filter( + Subscription.id == subscription_id + ).one() + except NoResultFound: + raise SubscriptionException("Subscription not found.") + + return SubscriptionDetails( + id=subscription.id, + user_id=subscription.user_id, + application_id=subscription.application_id, + pricing_id=subscription.pricing_id, + tier=subscription.tier, + is_active=subscription.is_active, + credit=subscription.credit, + balance=subscription.balance, + starts_at=subscription.starts_at, + expires_at=subscription.expires_at, + recurring=subscription.recurring, + created_at=subscription.created_at, + ) + def get_active_subscription( - self, username: str, application: str + self, user_id: int, application_id: int ) -> SubscriptionDetails: """ Get active subscription of a user. - :param username: str + :param email: str :param application: str :return: SubscriptionDetails object. """ @@ -213,9 +288,9 @@ def get_active_subscription( subscription = ( self.get_query() .filter( - Subscription.username == username, - Subscription.application == application, - Subscription.active == true(), + Subscription.user_id == user_id, + Subscription.application_id == application_id, + Subscription.is_active == true(), or_( Subscription.expires_at.is_(None), Subscription.expires_at > datetime.now(), @@ -227,29 +302,30 @@ def get_active_subscription( raise SubscriptionException return SubscriptionDetails( - username=username, - application=application, + id=subscription.id, + user_id=subscription.user_id, + application_id=subscription.application_id, + pricing_id=subscription.pricing_id, tier=subscription.tier, - active=subscription.active, + is_active=subscription.is_active, credit=subscription.credit, balance=subscription.balance, starts_at=subscription.starts_at, expires_at=subscription.expires_at, recurring=subscription.recurring, - created_by=subscription.created_by, created_at=subscription.created_at, ) - def get_active_subscriptions(self, username: str) -> List[SubscriptionDetails]: + def get_active_subscriptions(self, user_id: int) -> List[SubscriptionDetails]: """ Get all active subscriptions of a user. - :param username: str + :param email: str :return: list of SubscriptionDetails. """ try: subscriptions = self.get_query().filter( - Subscription.username == username, - Subscription.active == true(), + Subscription.user_id == user_id, + Subscription.is_active == true(), or_( Subscription.expires_at.is_(None), Subscription.expires_at > datetime.now(), @@ -260,51 +336,35 @@ def get_active_subscriptions(self, username: str) -> List[SubscriptionDetails]: return [ SubscriptionDetails( - username=subscription.username, - application=subscription.application, + id=subscription.id, + user_id=subscription.user_id, + application_id=subscription.application_id, + pricing_id=subscription.pricing_id, tier=subscription.tier, - active=subscription.active, + is_active=subscription.is_active, credit=subscription.credit, balance=subscription.balance, expires_at=subscription.expires_at, recurring=subscription.recurring, - created_by=subscription.created_by, created_at=subscription.created_at, ) for subscription in subscriptions ] def update_balance_in_subscription( - self, username: str, application: str, tier: str, redis: Redis + self, subscription: Subscription, redis: Redis ) -> None: """ Update balance in subscription. - :param username: str + :param email: str :param application: str :param tier: str :param redis: Redis object. :return: None """ - try: - subscription = ( - self.get_query() - .filter( - Subscription.username == username, - Subscription.application == application, - Subscription.tier == tier, - Subscription.active == true(), - or_( - Subscription.expires_at.is_(None), - Subscription.expires_at > datetime.now(), - ), - ) - .one() - ) - except NoResultFound: - raise SubscriptionException - with get_and_reset_balance_in_cache( - username, application, tier, redis + # FIXME why not use subscription.id? + subscription.user_id, subscription.application_id, subscription.tier, redis ) as balance: subscription.balance = subscription.credit - balance self.session.add(subscription) diff --git a/apihub/subscription/router.py b/apihub/subscription/router.py index 68d153a..e66ef04 100644 --- a/apihub/subscription/router.py +++ b/apihub/subscription/router.py @@ -7,20 +7,22 @@ from ..common.db_session import create_session from ..security.schemas import ( - UserBase, + UserBaseWithId, ) -from ..security.depends import require_admin, require_token +from ..security.depends import require_admin, require_publisher, require_token, require_user, require_logged_in from ..security.queries import UserQuery, UserException from .schemas import ( SubscriptionCreate, SubscriptionIn, ApplicationCreate, + ApplicationCreateWithOwner, + SubscriptionToken, ) from .queries import ( SubscriptionQuery, SubscriptionException, - SubscriptionPricingException, + PricingException, ApplicationQuery, ApplicationException, ) @@ -40,13 +42,18 @@ class SubscriptionSettings(BaseSettings): def create_application( application: ApplicationCreate, session: Session = Depends(create_session), - username: str = Depends(require_admin), + publisher: str = Depends(require_publisher), ): """ Create an application. """ + applicationCreateWithOwner = ApplicationCreateWithOwner.copy( + application, + update={"owner": publisher} + ) + try: - return ApplicationQuery(session).create_application(application) + return ApplicationQuery(session).create_application(applicationCreateWithOwner) except ApplicationException as e: raise HTTPException(status_code=400, detail=str(e)) @@ -54,7 +61,7 @@ def create_application( @router.get("/application", response_model=List[ApplicationCreate]) def get_applications( session: Session = Depends(create_session), - user: UserBase = Depends(require_token), + user: str = Depends(require_logged_in), ): """ List all applications. @@ -67,48 +74,48 @@ def get_applications( raise HTTPException(400, detail=str(e)) -@router.get("/application/{application}", response_model=ApplicationCreate) +@router.get("/application/{path}", response_model=ApplicationCreate) def get_application( - application: str, + path: str, session: Session = Depends(create_session), - username: str = Depends(require_admin), + user: str = Depends(require_logged_in), ): try: """ Get an application. """ - return ApplicationQuery(session).get_application(application) + return ApplicationQuery(session).get_application_by_path(path) except ApplicationException: - raise HTTPException(400, "Error while retrieving applications") + raise HTTPException(400, f"Error while retrieving application with path {path}") @router.post("/subscription") def create_subscription( subscription: SubscriptionIn, - username: str = Depends(require_admin), + admin: str = Depends(require_admin), session=Depends(create_session), ): - # make sure the username exists. + # make sure the email exists. try: - UserQuery(session).get_user_by_username(subscription.username) + UserQuery(session).get_user_by_id(subscription.user_id) except UserException: - raise HTTPException(401, f"User {subscription.username} not found.") + raise HTTPException(401, f"User {subscription.user_id} not found.") # make sure the application is not currently active. try: SubscriptionQuery(session).get_active_subscription( - subscription.username, subscription.application + subscription.user_id, subscription.application_id ) raise HTTPException( - 403, f"Application {subscription.application} already exists." + 403, f"Subscription for applicaiton {subscription.application_id} already exists." ) except SubscriptionException: pass try: - ApplicationQuery(session).get_application(subscription.application) + ApplicationQuery(session).get_application(subscription.application_id) except ApplicationException: - raise HTTPException(404, f"Application {subscription.application} not found.") + raise HTTPException(404, f"Application {subscription.application_id} not found.") if subscription.expires_at is None: subscription.expires_at = datetime.now() + timedelta( @@ -116,13 +123,13 @@ def create_subscription( ) subscription_create = SubscriptionCreate( - username=subscription.username, - application=subscription.application, + user_id=subscription.user_id, + application_id=subscription.application_id, + pricing_id=subscription.pricing_id, tier=subscription.tier, starts_at=datetime.now(), expires_at=subscription.expires_at, recurring=subscription.recurring, - created_by=username, ) try: query = SubscriptionQuery(session) @@ -130,19 +137,19 @@ def create_subscription( return subscription_create except SubscriptionException as e: raise HTTPException(400, str(e)) - except SubscriptionPricingException as e: + except PricingException as e: raise HTTPException(400, str(e)) @router.get("/subscription/{application}") def get_active_subscription( - application: str, - user: UserBase = Depends(require_token), + application: int, + user: UserBaseWithId = Depends(require_logged_in), session=Depends(create_session), ): query = SubscriptionQuery(session) try: - subscription = query.get_active_subscription(user.username, application) + subscription = query.get_active_subscription(user.id, application) except SubscriptionException: raise HTTPException(400, "Subscription not found") @@ -151,43 +158,33 @@ def get_active_subscription( @router.get("/subscription") def get_active_subscriptions( - user: UserBase = Depends(require_token), + user: UserBaseWithId = Depends(require_user), session=Depends(create_session), ): if not user.is_user: return [] - username = user.username query = SubscriptionQuery(session) try: - subscriptions = query.get_active_subscriptions(username) + subscriptions = query.get_active_subscriptions(user.id) except SubscriptionException: return [] - return [ - SubscriptionIn( - username=subscription.username, - application=subscription.application, - tier=subscription.tier, - expires_at=subscription.expires_at, - recurring=subscription.recurring, - ) - for subscription in subscriptions - ] + return subscriptions class SubscriptionTokenResponse(BaseModel): - username: str + email: str application: str token: str expires_time: int -@router.get("/token/{application}") +@router.get("/token/{path}", response_model=SubscriptionToken) async def get_application_token( - application: str, - user: UserBase = Depends(require_token), - username: Optional[str] = None, + path: str, + user: UserBaseWithId = Depends(require_user), + email: Optional[str] = None, expires_days: Optional[ int ] = SubscriptionSettings().subscription_token_expires_days, @@ -196,16 +193,16 @@ async def get_application_token( query = SubscriptionQuery(session) if user.is_user: - username = user.username + email = user.email expires_days = SubscriptionSettings().subscription_token_expires_days else: - if username is None: - raise HTTPException(401, "username is missing") + if email is None: + raise HTTPException(401, "email is missing") try: - subscription = query.get_active_subscription(username, application) + subscription = query.get_active_subscription_by_path(user.id, path) except SubscriptionException: - raise HTTPException(401, f"No active subscription found for user {username}") + raise HTTPException(401, f"No active subscription found for user {email}") if subscription.balance > subscription.credit: raise HTTPException(HTTP_429_TOO_MANY_REQUESTS, "You have used up your credit") @@ -215,16 +212,15 @@ async def get_application_token( if expires_days > subscription_expires_timedelta.days: expires_days = subscription_expires_timedelta.days - Authorize = AuthJWT() - expires_time = timedelta(days=expires_days) - access_token = Authorize.create_access_token( - subject=username, - user_claims={"subscription": application, "tier": subscription.tier}, - expires_time=expires_time, - ) - return SubscriptionTokenResponse( - username=username, - application=application, - token=access_token, - expires_time=expires_time.seconds, + subscription_token = SubscriptionToken( + email=email, + name = user.name, + user_id=user.id, + role=user.role, + application=path, + tier=subscription.tier, + application_id=subscription.application_id, + subscription_id=subscription.id, + expires_days=expires_days, ) + return subscription_token \ No newline at end of file diff --git a/apihub/subscription/schemas.py b/apihub/subscription/schemas.py index 8e662b8..1f5149d 100644 --- a/apihub/subscription/schemas.py +++ b/apihub/subscription/schemas.py @@ -3,6 +3,53 @@ from enum import Enum from pydantic import BaseModel +from fastapi_jwt_auth import AuthJWT + +from ..security.schemas import ( SecurityToken ) + + +class SubscriptionToken(SecurityToken): + subscription_id: int + application_id: int + tier: str + application: str + access_token: Optional[str] = None + + def __init__(self, **kwargs): + super().__init__(**kwargs) + self.access_token = self.to_token() + + def to_token(self): + Authorize = AuthJWT() + access_token = Authorize.create_access_token( + subject=self.email, + user_claims={ + "name": self.name, + "role": self.role, + "user_id": self.user_id, + "subscription_id": self.subscription_id, + "application_id": self.application_id, + "tier": self.tier, + "application": self.application, + "expires_days": self.expires_days, + }, + ) + return access_token + + @classmethod + def from_token(cls, Authorize: AuthJWT): + email = Authorize.get_jwt_subject() + claims = Authorize.get_raw_jwt() + return cls( + email=email, + name=claims["name"], + role=claims["role"], + user_id=claims["user_id"], + subscription_id=claims["subscription_id"], + application_id=claims["application_id"], + tier=claims["tier"], + application=claims["application"], + ) class SubscriptionTier(str, Enum): @@ -10,17 +57,17 @@ class SubscriptionTier(str, Enum): STANDARD = "STANDARD" PREMIUM = "PREMIUM" -class SubscriptionPricingBase(BaseModel): +class PricingBase(BaseModel): tier: SubscriptionTier price: int credit: int -class SubscriptionPricingCreate(SubscriptionPricingBase): +class PricingCreate(PricingBase): application: str -class SubscriptionPricingDetails(SubscriptionPricingCreate): +class PricingDetails(PricingCreate): id: int @@ -31,16 +78,25 @@ class ApplicationBase(BaseModel): class ApplicationCreate(ApplicationBase): - pricing: List[SubscriptionPricingBase] + pricings: List[PricingBase] + + +class ApplicationCreateWithOwner(ApplicationCreate): + user_id: int + + +class ApplicationDetailsWithId(ApplicationCreateWithOwner): + id: int class SubscriptionBase(BaseModel): - username: str - application: str + user_id: int + application_id: int tier: SubscriptionTier class SubscriptionIn(SubscriptionBase): + pricing_id: int expires_at: Optional[datetime] = None recurring: bool = False @@ -48,12 +104,10 @@ class SubscriptionIn(SubscriptionBase): class SubscriptionCreate(SubscriptionIn): credit: Optional[int] = 0 starts_at: datetime = datetime.now() - active: bool = True - created_by: str notes: Optional[str] = None - application: str class SubscriptionDetails(SubscriptionCreate): + id: int created_at: datetime - balance: int + balance: int \ No newline at end of file diff --git a/poetry.lock b/poetry.lock index 3f08a86..cd1bedb 100644 --- a/poetry.lock +++ b/poetry.lock @@ -1,5 +1,26 @@ # This file is automatically @generated by Poetry and should not be changed by hand. +[[package]] +name = "alembic" +version = "1.9.1" +description = "A database migration tool for SQLAlchemy." +category = "dev" +optional = false +python-versions = ">=3.7" +files = [ + {file = "alembic-1.9.1-py3-none-any.whl", hash = "sha256:a9781ed0979a20341c2cbb56bd22bd8db4fc1913f955e705444bd3a97c59fa32"}, + {file = "alembic-1.9.1.tar.gz", hash = "sha256:f9f76e41061f5ebe27d4fe92600df9dd612521a7683f904dab328ba02cffa5a2"}, +] + +[package.dependencies] +importlib-metadata = {version = "*", markers = "python_version < \"3.9\""} +importlib-resources = {version = "*", markers = "python_version < \"3.9\""} +Mako = "*" +SQLAlchemy = ">=1.3.0" + +[package.extras] +tz = ["python-dateutil"] + [[package]] name = "async-timeout" version = "4.0.2" @@ -967,6 +988,27 @@ roundrobin = ">=0.0.2" typing-extensions = ">=3.7.4.3" Werkzeug = ">=2.0.0" +[[package]] +name = "mako" +version = "1.2.4" +description = "A super-fast templating language that borrows the best ideas from the existing templating languages." +category = "dev" +optional = false +python-versions = ">=3.7" +files = [ + {file = "Mako-1.2.4-py3-none-any.whl", hash = "sha256:c97c79c018b9165ac9922ae4f32da095ffd3c4e6872b45eded42926deea46818"}, + {file = "Mako-1.2.4.tar.gz", hash = "sha256:d60a3903dc3bb01a18ad6a89cdbe2e4eadc69c0bc8ef1e3773ba53d44c3f7a34"}, +] + +[package.dependencies] +importlib-metadata = {version = "*", markers = "python_version < \"3.8\""} +MarkupSafe = ">=0.9.2" + +[package.extras] +babel = ["Babel"] +lingua = ["lingua"] +testing = ["pytest"] + [[package]] name = "markupsafe" version = "2.1.1" @@ -2420,4 +2462,4 @@ cffi = ["cffi (>=1.11)"] [metadata] lock-version = "2.0" python-versions = "^3.7" -content-hash = "eec650345fe015d0cafb0ec7b621bea41f0871ab93f6d742fb8f8d76b2ac1593" +content-hash = "99aa50f05c3b71e5387e8b5084f38f2ca49305473c8726b5289f79b3729e95c3" diff --git a/pyproject.toml b/pyproject.toml index 0504ca5..2874670 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -15,6 +15,8 @@ apihub_admin = "apihub.admin:create_all_statements" [tool.poetry.group.dev.dependencies] openapi-spec-validator = "^0.5.1" +alembic = "^1.9.1" +pyyaml = "^6.0" [tool.black] line-length = 88 @@ -42,7 +44,7 @@ select = "B,C,E,F,W,T4,B9" [tool.pytest.ini_options] minversion = "6.0" python_files = "*.py" -addopts = "--color=yes -p no:warnings --ignore=templates --ignore=static --doctest-glob *.rst --ignore=apihub" +addopts = "--color=yes -p no:warnings --ignore=templates --ignore=static --doctest-glob *.rst --ignore=apihub --ignore=alembic" doctest_optionflags= "NORMALIZE_WHITESPACE ELLIPSIS" [tool.poetry.dependencies] @@ -74,4 +76,4 @@ locust = "^2.13.0" [build-system] requires = ["poetry-core>=1.0.0"] -build-backend = "poetry.core.masonry.api" +build-backend = "poetry.core.masonry.api" \ No newline at end of file diff --git a/tests/conftest.py b/tests/conftest.py index 609b637..5bc5336 100644 --- a/tests/conftest.py +++ b/tests/conftest.py @@ -6,6 +6,9 @@ create_database, drop_database, ) +from apihub.security.models import * +from apihub.subscription.models import * +from apihub.activity.models import * DB_ENGINE = get_db_engine() diff --git a/tests/test_activity.py b/tests/test_activity.py deleted file mode 100644 index 27a3bbe..0000000 --- a/tests/test_activity.py +++ /dev/null @@ -1,106 +0,0 @@ -import pytest -from datetime import datetime -import factory -from fastapi import FastAPI -from fastapi.testclient import TestClient - -from apihub.common.db_session import create_session -from apihub.subscription.router import router -from apihub.subscription.models import SubscriptionTier -from apihub.activity.models import Activity -from apihub.activity.queries import ActivityQuery, ActivityException -from apihub.activity.schemas import ActivityCreate, ActivityStatus - - -class ActivityFactory(factory.alchemy.SQLAlchemyModelFactory): - class Meta: - model = Activity - - id = factory.Sequence(int) - created_at = factory.LazyFunction(datetime.now) - request = factory.Sequence(lambda n: f"app{n}") - username = factory.Sequence(lambda n: f"tester{n}") - tier = SubscriptionTier.TRIAL - status = ActivityStatus.PROCESSED - request_key = factory.Sequence(lambda n: f"request_key{n}") - result = "" - payload = "" - ip_address = "" - latency = 0.0 - - -@pytest.fixture(scope="function") -def client(db_session): - def _create_session(): - try: - yield db_session - finally: - pass - - app = FastAPI() - app.include_router(router) - app.dependency_overrides[create_session] = _create_session - ActivityFactory._meta.sqlalchemy_session = db_session - ActivityFactory._meta.sqlalchemy_session_persistence = "commit" - ActivityFactory( - username="tester", - request="async/app1", - request_key="app1_key", - status=ActivityStatus.PROCESSED, - ) - yield TestClient(app) - - -@pytest.fixture(scope="function") -def query(db_session): - yield ActivityQuery(db_session) - - -class TestActivity: - def test_create_activity(self, query): - query.create_activity( - ActivityCreate( - request="async/test", - username="ahmed", - tier=SubscriptionTier.TRIAL, - status=ActivityStatus.ACCEPTED, - request_key="async/test_key1234", - result="", - payload="", - ip_address="", - latency=0.0, - ) - ) - - assert ( - query.get_activity_by_key("async/test_key1234").request_key - == "async/test_key1234" - ) - - def test_get_activity_by_key(self, client, query): - assert query.get_activity_by_key("app1_key").request_key == "app1_key" - - with pytest.raises(ActivityException): - query.get_activity_by_key("key 2") - - def test_update_activity(self, client, query): - activity = query.get_activity_by_key("app1_key") - assert activity.tier == SubscriptionTier.TRIAL - - query.update_activity( - "app1_key", - **{"tier": SubscriptionTier.STANDARD, "ip_address": "test ip"}, - ) - - activity = query.get_activity_by_key("app1_key") - assert ( - activity.tier == SubscriptionTier.STANDARD - and activity.ip_address == "test ip" - and activity.latency > 0.0 - ) - - with pytest.raises(ActivityException): - query.update_activity( - "not existing", - **{"tier": SubscriptionTier.STANDARD, "ip_address": "test ip"}, - ) diff --git a/tests/test_result.py b/tests/test_result.py index 4898fcd..8242459 100644 --- a/tests/test_result.py +++ b/tests/test_result.py @@ -1,30 +1,30 @@ import pytest -from apihub.activity.queries import ActivityQuery -from apihub.activity.schemas import ActivityStatus -from .test_activity import ActivityFactory +# from apihub.activity.queries import ActivityQuery +# from apihub.activity.schemas import ActivityStatus +# from .test_activity import ActivityFactory message_id = "ab7fe542-bdf2-11eb-b401-f21898b454f0" -@pytest.fixture(scope="function") -def query(db_session): - ActivityFactory._meta.sqlalchemy_session = db_session - ActivityFactory._meta.sqlalchemy_session_persistence = "commit" - ActivityFactory( - username="tester", - request="async/app1", - request_key=message_id, - status=ActivityStatus.ACCEPTED, - ) - yield ActivityQuery(db_session) +# @pytest.fixture(scope="function") +# def query(db_session): +# ActivityFactory._meta.sqlalchemy_session = db_session +# ActivityFactory._meta.sqlalchemy_session_persistence = "commit" +# ActivityFactory( +# username="tester", +# request="async/app1", +# request_key=message_id, +# status=ActivityStatus.ACCEPTED, +# ) +# yield ActivityQuery(db_session) class TestResultWriter: - def test_basic(self, db_session, query, monkeypatch): - activity = query.get_activity_by_key(message_id) - assert activity.status == ActivityStatus.ACCEPTED + def test_basic(self, db_session, monkeypatch): + # activity = query.get_activity_by_key(message_id) + # assert activity.status == ActivityStatus.ACCEPTED monkeypatch.setenv("MONITORING", "FALSE") from apihub.result import ResultWriter @@ -39,8 +39,8 @@ def test_basic(self, db_session, query, monkeypatch): except Exception: pytest.fail("worker raised exception") - activity = query.get_activity_by_key(message_id) - assert activity.status == ActivityStatus.PROCESSED + # activity = query.get_activity_by_key(message_id) + # assert activity.status == ActivityStatus.PROCESSED def test_command(self, monkeypatch): monkeypatch.setenv("MONITORING", "FALSE") diff --git a/tests/test_security.py b/tests/test_security.py index 4b640bd..793a3bf 100644 --- a/tests/test_security.py +++ b/tests/test_security.py @@ -13,8 +13,8 @@ from apihub.common.db_session import create_session from apihub.security.models import User from apihub.security.queries import UserQuery -from apihub.security.schemas import UserCreate, UserType, UserRegister -from apihub.security.router import router, AuthenticateResponse +from apihub.security.schemas import UserCreate, UserType, UserRegister, SecurityToken +from apihub.security.router import router from apihub.security.depends import require_admin from apihub.security.helpers import hash_password @@ -29,7 +29,8 @@ class Meta: model = User id = factory.Sequence(int) - username = factory.Sequence(lambda n: f"tester{n}") + name = factory.Sequence(lambda n: f"Mr. Tester{n}") + email = factory.Sequence(lambda n: f"tester{n}@tester.com") salt = SALT hashed_password = itemgetter(1)(hash_password("password", salt=SALT)) role = UserType.USER @@ -40,18 +41,18 @@ def test_user_create(db_session): query = UserQuery(db_session) query.create_user( user=UserCreate( - username="tester", + name="Mr. Tester", email="newuser@test.com", password="testpassword", role=UserType.USER, ) ) - user = query.get_user_by_username(username="tester") + user = query.get_user_by_email(email="newuser@test.com") assert user is not None another_user = query.get_user_by_id(user_id=user.id) - assert user.username == another_user.username + assert user.email == another_user.email @pytest.fixture(scope="function") @@ -76,56 +77,56 @@ def protected(Authorize: AuthJWT = Depends()): Authorize.jwt_required() user = Authorize.get_jwt_subject() - roles = Authorize.get_raw_jwt()["roles"] - return {"user": user, "roles": roles} + role = Authorize.get_raw_jwt()["role"] + return {"user": user, "role": role} @app.get("/admin") - def admin(username=Depends(require_admin)): - return username + def admin(email=Depends(require_admin)): + return email app.dependency_overrides[create_session] = _create_session UserFactory._meta.sqlalchemy_session = db_session UserFactory._meta.sqlalchemy_session_persistence = "commit" - UserFactory(username="tester", role=UserType.USER) - UserFactory(username="admin", role=UserType.ADMIN) - UserFactory(username="manager", role=UserType.MANAGER) - UserFactory(username="user", role=UserType.USER) - UserFactory(username="app", role=UserType.APP) + UserFactory(email="tester@test.com", role=UserType.USER) + UserFactory(email="admin@test.com", role=UserType.ADMIN) + UserFactory(email="publisher@test.com", role=UserType.PUBLISHER) + UserFactory(email="user@test.com", role=UserType.USER) + UserFactory(email="app@test.com", role=UserType.APP) yield TestClient(app) class TestAuthenticate: - def _make_auth_header(self, username, password): + def _make_auth_header(self, email, password): from base64 import b64encode - raw = b64encode(f"{username}:{password}".encode("ascii")).decode("ascii") + raw = b64encode(f"{email}:{password}".encode("ascii")).decode("ascii") return {"Authorization": f"Basic {raw}"} def test_authenticate_wrong_user(self, client): response = client.get( "/_authenticate", - headers=self._make_auth_header("nosuchuser", "password"), + headers=self._make_auth_header("nosuchuser@test.com", "password"), ) assert response.status_code == 403 def test_authenticate_wrong_password(self, client): response = client.get( "/_authenticate", - headers=self._make_auth_header("tester", "nosuchpassword"), + headers=self._make_auth_header("tester@test.com", "nosuchpassword"), ) assert response.status_code == 403 def test_authenticate(self, client): response = client.get( "/_authenticate", - headers=self._make_auth_header("tester", "password"), + headers=self._make_auth_header("tester@test.com", "password"), params={"expires_days": 2}, ) assert response.status_code == 200 - assert AuthenticateResponse.parse_obj(response.json()) + assert SecurityToken.parse_obj(response.json()).access_token is not None def test_pretected_no_token(self, client): response = client.get( @@ -136,10 +137,10 @@ def test_pretected_no_token(self, client): def test_token(self, client): response = client.get( "/_authenticate", - headers=self._make_auth_header("tester", "password"), + headers=self._make_auth_header("tester@test.com", "password"), ) assert response.status_code == 200 - auth_response = AuthenticateResponse.parse_obj(response.json()) + auth_response = SecurityToken.parse_obj(response.json()) token = auth_response.access_token response = client.get( "/protected", headers={"Authorization": f"Bearer {token}"} @@ -149,46 +150,40 @@ def test_token(self, client): def test_require_admin_when_admin(self, client): response = client.get( "/_authenticate", - headers=self._make_auth_header("manager", "password"), + headers=self._make_auth_header("admin@test.com", "password"), ) assert response.status_code == 200 - auth_response = AuthenticateResponse.parse_obj(response.json()) + auth_response = SecurityToken.parse_obj(response.json()) token = auth_response.access_token response = client.get("/admin", headers={"Authorization": f"Bearer {token}"}) - assert response.status_code == 403 + assert response.status_code == 200 def test_require_admin_when_manager(self, client): response = client.get( "/_authenticate", - headers=self._make_auth_header("admin", "password"), + headers=self._make_auth_header("publisher@test.com", "password"), ) assert response.status_code == 200 - auth_response = AuthenticateResponse.parse_obj(response.json()) + auth_response = SecurityToken.parse_obj(response.json()) token = auth_response.access_token response = client.get("/admin", headers={"Authorization": f"Bearer {token}"}) - assert response.status_code == 200 + assert response.status_code == 403 def test_create_and_get_user(self, client): response = client.get( "/_authenticate", - headers=self._make_auth_header("admin", "password"), + headers=self._make_auth_header("admin@test.com", "password"), ) assert response.status_code == 200 - auth_response = AuthenticateResponse.parse_obj(response.json()) + auth_response = SecurityToken.parse_obj(response.json()) token = auth_response.access_token - response = client.get("/user/me", headers={"Authorization": f"Bearer {token}"}) + response = client.get("/user", headers={"Authorization": f"Bearer {token}"}) assert response.status_code == 200 - assert response.json().get("username") == "admin" - - response = client.get( - "/user/user", headers={"Authorization": f"Bearer {token}"} - ) - assert response.status_code == 200 - assert response.json().get("username") == "user" + assert response.json().get("email") == "admin@test.com" new_user = UserCreate( - username="newuser", + name="New User", email="newuser@test.com", password="password", role="user", @@ -199,26 +194,34 @@ def test_create_and_get_user(self, client): json=new_user.dict(), ) assert response.status_code == 200 + + response = client.get( + "/_authenticate", + headers=self._make_auth_header("newuser@test.com", "password"), + ) + assert response.status_code == 200 + auth_response = SecurityToken.parse_obj(response.json()) + token = auth_response.access_token response = client.get( - f"/user/{new_user.username}", headers={"Authorization": f"Bearer {token}"} + f"/user", headers={"Authorization": f"Bearer {token}"} ) assert response.status_code == 200 - assert response.json().get("username") == new_user.username + assert response.json().get("email") == new_user.email def test_get_users(self, client): response = client.get( "/_authenticate", - headers=self._make_auth_header("admin", "password"), + headers=self._make_auth_header("admin@test.com", "password"), ) assert response.status_code == 200 - auth_response = AuthenticateResponse.parse_obj(response.json()) + auth_response = SecurityToken.parse_obj(response.json()) token = auth_response.access_token response = client.get( "/user", headers={"Authorization": f"Bearer {token}"}, - json={"usernames": "admin,manager,user"}, + json={"emails": "admin@test.com,publisher@test.com,user@test.com"}, ) assert response.status_code == 200 assert len(response.json()) == 3 @@ -226,10 +229,10 @@ def test_get_users(self, client): def test_list_users(self, client): response = client.get( "/_authenticate", - headers=self._make_auth_header("admin", "password"), + headers=self._make_auth_header("admin@test.com", "password"), ) assert response.status_code == 200 - auth_response = AuthenticateResponse.parse_obj(response.json()) + auth_response = SecurityToken.parse_obj(response.json()) token = auth_response.access_token response = client.get( @@ -242,10 +245,10 @@ def test_list_users(self, client): def test_change_password_user(self, client): response = client.get( "/_authenticate", - headers=self._make_auth_header("user", "password"), + headers=self._make_auth_header("user@test.com", "password"), ) assert response.status_code == 200 - auth_response = AuthenticateResponse.parse_obj(response.json()) + auth_response = SecurityToken.parse_obj(response.json()) token = auth_response.access_token response = client.post( @@ -257,13 +260,13 @@ def test_change_password_user(self, client): response = client.get( "/_authenticate", - headers=self._make_auth_header("user", "password"), + headers=self._make_auth_header("user@test.com", "password"), ) assert response.status_code == 403 response = client.get( "/_authenticate", - headers=self._make_auth_header("user", "newpassword"), + headers=self._make_auth_header("user@test.com", "newpassword"), ) assert response.status_code == 200 @@ -277,20 +280,20 @@ def test_change_password_user(self, client): def test_change_password_admin(self, client): response = client.get( "/_authenticate", - headers=self._make_auth_header("admin", "password"), + headers=self._make_auth_header("admin@test.com", "password"), ) assert response.status_code == 200 - auth_response = AuthenticateResponse.parse_obj(response.json()) + auth_response = SecurityToken.parse_obj(response.json()) token = auth_response.access_token response = client.get( "/_authenticate", - headers=self._make_auth_header("user", "password"), + headers=self._make_auth_header("user@test.com", "password"), ) assert response.status_code == 200 response = client.post( - "/user/user/_password", + "/user/user@test.com/_password", headers={"Authorization": f"Bearer {token}"}, json={"password": "newpassword"}, ) @@ -298,18 +301,18 @@ def test_change_password_admin(self, client): response = client.get( "/_authenticate", - headers=self._make_auth_header("user", "password"), + headers=self._make_auth_header("user@test.com", "password"), ) assert response.status_code == 403 response = client.get( "/_authenticate", - headers=self._make_auth_header("user", "newpassword"), + headers=self._make_auth_header("user@test.com", "newpassword"), ) assert response.status_code == 200 response = client.post( - "/user/user/_password", + "/user/user@test.com/_password", headers={"Authorization": f"Bearer {token}"}, json={"password": "password"}, ) @@ -318,14 +321,14 @@ def test_change_password_admin(self, client): def test_register(self, client): response = client.get( "/_authenticate", - headers=self._make_auth_header("app", "password"), + headers=self._make_auth_header("app@test.com", "password"), ) assert response.status_code == 200 - auth_response = AuthenticateResponse.parse_obj(response.json()) + auth_response = SecurityToken.parse_obj(response.json()) token = auth_response.access_token new_user = UserRegister( - username="newuser", + name="New User", email="newuser@test.com", password="password", ) @@ -338,12 +341,12 @@ def test_register(self, client): response = client.get( "/_authenticate", - headers=self._make_auth_header("newuser", "password"), + headers=self._make_auth_header("newuser@test.com", "password"), ) assert response.status_code == 200 - auth_response = AuthenticateResponse.parse_obj(response.json()) + auth_response = SecurityToken.parse_obj(response.json()) token = auth_response.access_token - response = client.get("/user/me", headers={"Authorization": f"Bearer {token}"}) + response = client.get("/user", headers={"Authorization": f"Bearer {token}"}) assert response.status_code == 200 - assert response.json().get("username") == new_user.username + assert response.json().get("email") == new_user.email diff --git a/tests/test_server.py b/tests/test_server.py index 5b715ff..92da18e 100644 --- a/tests/test_server.py +++ b/tests/test_server.py @@ -8,8 +8,8 @@ from openapi_spec_validator import validate_spec, openapi_v30_spec_validator from apihub.common.db_session import create_session -from apihub.subscription.depends import require_subscription -from apihub.subscription.schemas import SubscriptionTier, SubscriptionBase +from apihub.subscription.depends import require_subscription, SubscriptionToken +from apihub.subscription.schemas import SubscriptionTier from apihub.utils import make_topic @@ -21,18 +21,12 @@ def _create_session(): def _ip_rate_limited(): pass - def _require_subscription(application:str): - return SubscriptionBase( - username="test", tier=SubscriptionTier.TRIAL, application=application, - ) monkeypatch.setenv("OUT_KIND", "MEM") from apihub.server import api, ip_rate_limited api.dependency_overrides[ip_rate_limited] = _ip_rate_limited - api.dependency_overrides[require_subscription] = _require_subscription - api.dependency_overrides[create_session] = _create_session yield TestClient(api) @@ -71,9 +65,15 @@ def get(self, application): monkeypatch.setattr( apihub.server, "get_definition_manager", _get_definition_manager ) + token = SubscriptionToken( + user_id=1, subscription_id=1, application_id=1, + email="user@test.com", tier=SubscriptionTier.TRIAL, application="test", + role="user", name="user", expires_days=1, + ) response = client.post( - "/async/test", params={"text": "this is simple"}, json={"probability": 0.6} + "/async/test", params={"text": "this is simple"}, json={"probability": 0.6}, + headers={"Authorization": f"Bearer {token.access_token}"} ) assert response.status_code == 200 @@ -125,8 +125,15 @@ def get(self, application): apihub.server, "get_definition_manager", _get_definition_manager ) + token = SubscriptionToken( + user_id=1, subscription_id=1, application_id=1, + email="user@test.com", tier=SubscriptionTier.TRIAL, application="test", + role="user", name="user", expires_days=1, + ) + response = client.post( - "/async/test", params={}, json={"probability": 0.6} + "/async/test", params={}, json={"probability": 0.6}, + headers={"Authorization": f"Bearer {token.access_token}"} ) assert response.status_code == 422 diff --git a/tests/test_subscription.py b/tests/test_subscription.py index 723ac92..f634419 100644 --- a/tests/test_subscription.py +++ b/tests/test_subscription.py @@ -9,22 +9,23 @@ from apihub.common.db_session import create_session from apihub.security.models import User -from apihub.security.schemas import UserBase, UserType -from apihub.security.depends import require_user, require_admin, require_token +from apihub.security.schemas import UserBase, UserType, UserBaseWithId +from apihub.security.depends import require_user, require_admin, require_token, require_publisher, require_logged_in from apihub.subscription.depends import ( require_subscription_balance, + SubscriptionToken, ) from apihub.subscription.models import ( Subscription, SubscriptionTier, Application, - SubscriptionPricing, + Pricing, ) from apihub.subscription.router import router from apihub.subscription.schemas import ( SubscriptionIn, ApplicationCreate, - SubscriptionPricingBase, + PricingBase, ) from apihub.security.helpers import hash_password @@ -43,16 +44,19 @@ class Meta: url = factory.Sequence(lambda n: f"app/{n}") description = "description" + created_at = factory.LazyFunction(datetime.now) + user_id = factory.Sequence(int) + -class SubscriptionPricingFactory(factory.alchemy.SQLAlchemyModelFactory): +class PricingFactory(factory.alchemy.SQLAlchemyModelFactory): class Meta: - model = SubscriptionPricing + model = Pricing id = factory.Sequence(int) tier = SubscriptionTier.TRIAL price = 100.0 credit = 100.0 - application = factory.Sequence(lambda n: f"app{n}") + application_id = factory.Sequence(int) class UserFactory(factory.alchemy.SQLAlchemyModelFactory): @@ -60,7 +64,8 @@ class Meta: model = User id = factory.Sequence(int) - username = factory.Sequence(lambda n: f"tester{n}") + name = factory.Sequence(lambda n: f"Mr Tester{n}") + email = factory.Sequence(lambda n: f"tester{n}@test.com") salt = SALT hashed_password = itemgetter(1)(hash_password("password", salt=SALT)) role = UserType.USER @@ -72,9 +77,7 @@ class Meta: model = Subscription id = factory.Sequence(int) - username = factory.Sequence(lambda n: f"tester{n}") - application = "test" - active = True + is_active = True tier = SubscriptionTier.TRIAL credit = 100 balance = 0 @@ -82,9 +85,25 @@ class Meta: expires_at = factory.LazyFunction(lambda: datetime.now() + timedelta(days=1)) recurring = False created_at = factory.LazyFunction(datetime.now) - created_by = "admin" + # created_by = "admin" notes = None + user_id = factory.Sequence(int) + application_id = factory.Sequence(int) + pricing_id = factory.Sequence(int) + + +def _require_admin_token(): + return UserBaseWithId(id=1, email="tester", name="tester", role=UserType.ADMIN) + + +def _require_user_token(): + return UserBaseWithId(id=1, email="tester", name="tester", role=UserType.USER) + + +def _require_publisher_token(): + return UserBaseWithId(id=1, email="tester", name="tester", role=UserType.PUBLISHER) + @pytest.fixture(scope="function") def client(db_session): @@ -94,79 +113,65 @@ def _create_session(): finally: pass - def _require_admin(): - return "admin" - - def _require_user(): - return "user" - - def _require_admin_token(): - return UserBase(username="tester", role=UserType.ADMIN) - - def _require_user_token(): - return UserBase(username="tester", role=UserType.USER) - app = FastAPI() app.include_router(router) app.dependency_overrides[create_session] = _create_session - app.dependency_overrides[require_admin] = _require_admin - app.dependency_overrides[require_user] = _require_user + app.dependency_overrides[require_admin] = _require_admin_token + app.dependency_overrides[require_user] = _require_user_token + app.dependency_overrides[require_publisher] = _require_publisher_token app.dependency_overrides[require_token] = _require_user_token + app.dependency_overrides[require_logged_in] = _require_user_token @app.get("/api_balance/{application}") def api_function_2( - application: str, username: str = Depends(require_subscription_balance) + application: str, subscription: SubscriptionToken = Depends(require_subscription_balance) ): pass UserFactory._meta.sqlalchemy_session = db_session UserFactory._meta.sqlalchemy_session_persistence = "commit" - UserFactory(username="tester", role=UserType.USER) + tester = UserFactory(id=100, email="tester@test.com", role=UserType.USER) + + UserFactory._meta.sqlalchemy_session = db_session + UserFactory._meta.sqlalchemy_session_persistence = "commit" + publisher = UserFactory(id=200, email="publisher@test.com", role=UserType.PUBLISHER) ApplicationFactory._meta.sqlalchemy_session = db_session ApplicationFactory._meta.sqlalchemy_session_persistence = "commit" - application = ApplicationFactory(name="test", url="/test") + application = ApplicationFactory(id=100, name="test", url="/test", user_id=publisher.id) - SubscriptionPricingFactory._meta.sqlalchemy_session = db_session - SubscriptionPricingFactory._meta.sqlalchemy_session_persistence = "commit" - pricing = SubscriptionPricingFactory( + PricingFactory._meta.sqlalchemy_session = db_session + PricingFactory._meta.sqlalchemy_session_persistence = "commit" + pricing = PricingFactory( + id=100, tier=SubscriptionTier.TRIAL, price=100, credit=100, - application="test", + application_id=application.id, ) SubscriptionFactory._meta.sqlalchemy_session = db_session SubscriptionFactory._meta.sqlalchemy_session_persistence = "commit" - - SubscriptionFactory(username="tester", application="test", credit=100) + SubscriptionFactory(user_id=tester.id, application_id=application.id, credit=100, pricing=pricing) yield TestClient(app) -def _require_admin_token(): - return UserBase(username="tester", role=UserType.ADMIN) - - -def _require_user_token(): - return UserBase(username="tester", role=UserType.USER) - - class TestApplication: def test_create_application(self, client): new_application = ApplicationCreate( name="app", url="/test", description="test", - pricing=[ - SubscriptionPricingBase( + pricings=[ + PricingBase( tier=SubscriptionTier.TRIAL, price=100, credit=100 ), - SubscriptionPricingBase( + PricingBase( tier=SubscriptionTier.STANDARD, price=200, credit=200 ), - SubscriptionPricingBase( + PricingBase( tier=SubscriptionTier.PREMIUM, price=300, credit=300 ), ], @@ -182,7 +187,7 @@ def test_create_application(self, client): ) response_json = response.json() - assert len(response_json["pricing"]) == 3 + assert len(response_json["pricings"]) == 3 def test_list_application(self, client, db_session): response = client.get("/application") @@ -197,30 +202,34 @@ def test_get_application(self, client, db_session): assert response.status_code == 200 response_json = response.json() assert ( - len(response_json["pricing"]) == 1 - and response_json["pricing"][0]["tier"] == "TRIAL" + len(response_json["pricings"]) == 1 + and response_json["pricings"][0]["tier"] == "TRIAL" ) class TestSubscription: def test_create_and_get_subscription(self, client, db_session): + UserFactory._meta.sqlalchemy_session = db_session + UserFactory._meta.sqlalchemy_session_persistence = "commit" + publisher = UserFactory(email="publisher1@test.com", role=UserType.PUBLISHER) + ApplicationFactory._meta.sqlalchemy_session = db_session ApplicationFactory._meta.sqlalchemy_session_persistence = "commit" - ApplicationFactory(name="application", url="/test") + application = ApplicationFactory(name="application", url="/test", user_id=publisher.id) - SubscriptionPricingFactory._meta.sqlalchemy_session = db_session - SubscriptionPricingFactory._meta.sqlalchemy_session_persistence = "commit" - SubscriptionPricingFactory( + PricingFactory._meta.sqlalchemy_session = db_session + PricingFactory._meta.sqlalchemy_session_persistence = "commit" + pricing = PricingFactory( tier=SubscriptionTier.TRIAL, price=100, credit=100, - application="application", ) # case 1: create subscription new_subscription = SubscriptionIn( - username="tester", - application="application", + user_id=publisher.id, + application_id=application.id, + pricing_id=pricing.id, tier=SubscriptionTier.TRIAL, expires_at=None, recurring=False, @@ -231,14 +240,24 @@ def test_create_and_get_subscription(self, client, db_session): ) assert response.status_code == 200 + def _require_logged_in(): + return publisher + + client.app.dependency_overrides[require_logged_in] = _require_logged_in + response = client.get( - "/subscription/application", + f"/subscription/{application.id}", ) assert response.status_code == 200 assert response.json().get("credit") == 100 - assert response.json().get("active") is True def test_get_all_subscriptions(self, client): + + def _require_user(): + return UserBaseWithId(id=100, email="", name="", role=UserType.USER) + + client.app.dependency_overrides[require_user] = _require_user + response = client.get( "/subscription", ) @@ -248,8 +267,9 @@ def test_get_all_subscriptions(self, client): def test_create_subscription_not_existing_user(self, client): new_subscription = SubscriptionIn( - username="not existing user", - application="app 1", + user_id=-1, + application_id=1, + pricing_id=1, tier=SubscriptionTier.TRIAL, expires_at=None, recurring=False, @@ -261,61 +281,31 @@ def test_create_subscription_not_existing_user(self, client): assert response.status_code == 401 def test_get_application_token(self, client, db_session): - SubscriptionFactory._meta.sqlalchemy_session = db_session - SubscriptionFactory._meta.sqlalchemy_session_persistence = "commit" - - ApplicationFactory(name="app") - SubscriptionPricingFactory( - tier=SubscriptionTier.TRIAL, price=100, credit=100, application="app" - ) - SubscriptionFactory(username="tester", application="app", credit=100) - - response = client.get( - "/token/app", - ) - assert response.status_code == 200, response.json() - assert response.json().get("token") is not None - - ApplicationFactory(name="app_2") - SubscriptionPricingFactory( - tier=SubscriptionTier.TRIAL, price=100, credit=100, application="app_2" - ) - SubscriptionFactory( - username="tester", application="app_2", active=False, credit=1000 - ) + def _require_user(): + return UserBaseWithId(id=100, email="", name="", role=UserType.USER) - response = client.get( - "/subscription/app_2", - ) - - assert response.status_code == 400, response.json() + client.app.dependency_overrides[require_user] = _require_user - def test_get_application_token_admin(self, client, db_session): - client.app.dependency_overrides[require_token] = _require_admin_token SubscriptionFactory._meta.sqlalchemy_session = db_session SubscriptionFactory._meta.sqlalchemy_session_persistence = "commit" - ApplicationFactory(name="app2") - SubscriptionPricingFactory( - tier=SubscriptionTier.TRIAL, price=100, credit=100, application="app2" + application = ApplicationFactory(name="app", user_id=100) + pricing = PricingFactory( + tier=SubscriptionTier.TRIAL, price=100, credit=100, application_id=application.id ) - SubscriptionFactory(username="tester", application="app2", credit=1000) + SubscriptionFactory(user_id=100, application_id=application.id, pricing_id=pricing.id, credit=100) response = client.get( - "/token/app2", - params={ - "username": "tester", - "expires_days": 30, - }, + "/token/app", ) - client.app.dependency_overrides[require_token] = _require_user_token assert response.status_code == 200, response.json() - assert response.json().get("token") is not None + assert response.json().get("access_token") is not None def test_create_duplicate_subscription(self, client, db_session): new_subscription = SubscriptionIn( - username="tester", - application="application", + user_id=100, + application_id=100, + pricing_id=100, tier=SubscriptionTier.TRIAL, expires_at=None, recurring=False, @@ -324,62 +314,24 @@ def test_create_duplicate_subscription(self, client, db_session): "/subscription", data=new_subscription.json(), ) - assert response.status_code == 404 + assert response.status_code == 403 - ApplicationFactory._meta.sqlalchemy_session = db_session - ApplicationFactory._meta.sqlalchemy_session_persistence = "commit" - ApplicationFactory(name="application", url="/test") - - SubscriptionPricingFactory._meta.sqlalchemy_session = db_session - SubscriptionPricingFactory._meta.sqlalchemy_session_persistence = "commit" - SubscriptionPricingFactory( - tier=SubscriptionTier.TRIAL, - price=100, - credit=100, - application="application", - ) - - new_subscription = SubscriptionIn( - username="tester", - application="application", - tier=SubscriptionTier.TRIAL, - expires_at=None, - recurring=False, - ) - response = client.post( - "/subscription", - data=new_subscription.json(), - ) - assert response.status_code == 200, response.json() + def test_require_balance(self, client, db_session): + def _require_user(): + return UserBaseWithId(id=100, email="", name="", role=UserType.USER) - response = client.post( - "/subscription", - data=new_subscription.json(), - ) - assert response.status_code == 403, response.json() + client.app.dependency_overrides[require_user] = _require_user - def test_require_balance(self, client, db_session): SubscriptionFactory._meta.sqlalchemy_session = db_session SubscriptionFactory._meta.sqlalchemy_session_persistence = "commit" - ApplicationFactory(name="app3") - SubscriptionPricingFactory( - tier=SubscriptionTier.TRIAL, price=100, credit=100, application="app3" - ) - SubscriptionFactory( - username="tester", - application="app3", - tier=SubscriptionTier.TRIAL, - credit=2, - ) - response = client.get( - "/token/app3", + "/token/test", ) assert response.status_code == 200, response.json() - token = response.json().get("token") + token = response.json().get("access_token") response = client.get( - "/api_balance/app3", headers={"Authorization": f"Bearer {token}"} + "/api_balance/test", headers={"Authorization": f"Bearer {token}"} ) - assert response.status_code == 200, response.json() + assert response.status_code == 200, response.json() \ No newline at end of file