Skip to content

Commit

Permalink
add comments, add expires_at to token
Browse files Browse the repository at this point in the history
  • Loading branch information
JackySu committed Mar 28, 2023
1 parent 65aa47c commit d07c564
Show file tree
Hide file tree
Showing 6 changed files with 29 additions and 18 deletions.
33 changes: 17 additions & 16 deletions api.py
Original file line number Diff line number Diff line change
Expand Up @@ -36,11 +36,12 @@

DEFAULT_PAGE_SIZE = 30
MAX_PAGE_SIZE = 50
# jinja2 template for server side render html
# templates = Jinja2Templates(directory="templates")

# 这里是 tokenUrl,而不是 token_url,是为了和 OAuth2 规范统一
# tokenUrl 是为了指定 OpenAPI 前端登录时的接口地址
# OAuthPasswordBearer Authorization Header 的 Bearer 取出来,然后传给 tokenUrl
# it is tokenUrl instead of kebab-case token_url to unify with OAuth2 scheme
# tokenUrl is to specify OpenAPI route address of token endpoint for frontend login
# OAuthPasswordBearer extracts Bearer from Authorization Header and send it to tokenUrl
oauth2_bearer = OAuth2PasswordBearer(tokenUrl=API_PREFIX + "/user/token")

pwd_context = CryptContext(schemes=["bcrypt"], deprecated="auto")
Expand All @@ -51,6 +52,7 @@ async def startup():
await init_db()


# for further refactor the functions down below should be moved to a utils package
def is_admin(user: User) -> bool:
return True if user.role else False

Expand Down Expand Up @@ -86,6 +88,7 @@ def _decode_token(token: str) -> dict:
raise unauthorized_error("Could not validate credentials")


# warning: this function returns all the users without limit and offset
async def get_all_users_db(session: AsyncSession = Depends(
get_session)) -> List[User]:
result = await session.execute(select(User))
Expand All @@ -108,6 +111,7 @@ async def get_user_with_email_db(
# return [User(**user.__dict__) for user in users]


# get current user from token by decoding it
async def get_current_user(token: str = Depends(oauth2_bearer),
session: AsyncSession = Depends(
get_session)) -> User:
Expand Down Expand Up @@ -135,12 +139,7 @@ async def get_current_user(token: str = Depends(oauth2_bearer),
return user


@router.on_event("startup")
async def on_startup():
await init_db()


@router.post("/user/signup")
@router.post("/user/signup", response_model=Token)
async def create_user(user: UserSignup,
session: AsyncSession = Depends(get_session)):
if not user.email or not re.match(
Expand Down Expand Up @@ -176,7 +175,7 @@ async def create_user(user: UserSignup,
"role": 0
},
expires=timedelta(minutes=ACCESS_TOKEN_EXPIRE_MINUTES))
response = {"access_token": token, "token_type": "bearer", "role": 0}
response = {"access_token": token, "token_type": "bearer", "role": 0, "expires_at": datetime.utcnow() + timedelta(minutes=ACCESS_TOKEN_EXPIRE_MINUTES)}
return response


Expand Down Expand Up @@ -287,7 +286,8 @@ async def login(form: OAuth2PasswordRequestForm = Depends(),
response = {
"access_token": token,
"token_type": "bearer",
"role": user.role
"role": user.role,
"expires_at": datetime.utcnow() + timedelta(minutes=ACCESS_TOKEN_EXPIRE_MINUTES)
}
return response

Expand Down Expand Up @@ -330,11 +330,12 @@ async def admin_get_all_users(response: Response,


@router.get("/admin/projects", response_model=List[ProjectWithUserAndTags])
async def admin_get_all_projects(response: Response,
per_page: int = DEFAULT_PAGE_SIZE,
page: int = DEFAULT_PAGE,
session: AsyncSession = Depends(get_session),
current_user: User = Depends(get_current_user)):
async def admin_get_all_projects(
response: Response,
per_page: int = DEFAULT_PAGE_SIZE,
page: int = DEFAULT_PAGE,
session: AsyncSession = Depends(get_session),
current_user: User = Depends(get_current_user)):
if not current_user:
raise HTTPException(status_code=status.HTTP_401_UNAUTHORIZED,
detail="Unauthorized")
Expand Down
4 changes: 4 additions & 0 deletions app.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,6 +3,8 @@
from fastapi.middleware.cors import CORSMiddleware
import api


# Create FastAPI app and add CORS middleware
app = FastAPI()
app.add_middleware(
CORSMiddleware,
Expand All @@ -12,9 +14,11 @@
allow_headers=["*"],
)

# Add API routes
app.include_router(api.router)


# redirect to docs
@app.get("/")
async def index():
return RedirectResponse(url="/docs", status_code=302)
Expand Down
1 change: 1 addition & 0 deletions database/database.env
Original file line number Diff line number Diff line change
@@ -0,0 +1 @@
DATABASE_URL=sqlite+aiosqlite:///./database/database.db
8 changes: 6 additions & 2 deletions db.py
Original file line number Diff line number Diff line change
Expand Up @@ -5,18 +5,22 @@
from sqlalchemy.ext.asyncio import AsyncSession, create_async_engine
from sqlalchemy.orm import sessionmaker


DATABASE_URL = 'sqlite+aiosqlite:///./database/database.db'
from dotenv import load_dotenv
# get the database url from the environment variable
load_dotenv("./database/database.env")
DATABASE_URL = os.getenv("DATABASE_URL")

engine = create_async_engine(DATABASE_URL, echo=True, future=True)


# initialize db and create tables if they don't exist
async def init_db():
async with engine.begin() as conn:
# await conn.run_sync(SQLModel.metadata.drop_all)
await conn.run_sync(SQLModel.metadata.create_all)


# create a new session and return it when called
async def get_session() -> AsyncSession:
async_session = sessionmaker(
engine, class_=AsyncSession, expire_on_commit=False
Expand Down
1 change: 1 addition & 0 deletions models.py
Original file line number Diff line number Diff line change
Expand Up @@ -44,6 +44,7 @@ class Token(SQLModel):
access_token: str
token_type: str
role: int
expires_at: datetime


class CategoryBase(SQLModel):
Expand Down
Binary file modified requirements.txt
Binary file not shown.

0 comments on commit d07c564

Please sign in to comment.