Feedback #1 (Open)
Wants to merge 9 commits into base: feedback
23 changes: 23 additions & 0 deletions .gitignore
@@ -0,0 +1,23 @@
# See https://help.github.com/articles/ignoring-files/ for more about ignoring files.

# dependencies
/node_modules
/.pnp
.pnp.js

# testing
/coverage

# production
/build

# misc
.DS_Store
.env.local
.env.development.local
.env.test.local
.env.production.local

npm-debug.log*
yarn-debug.log*
yarn-error.log*
24 changes: 24 additions & 0 deletions Backend/.gitignore
@@ -0,0 +1,24 @@
# common
*.pyc
*.csv
*.json
*.log
*.cfg
*.db
*.out
*.pth

# dir
Ignore_to_Push/
__pycache__/

# airflow
Airflow/logs/
airflow-webserver.pid
airflow.cfg
webserver_config.py
Airflow/dags/practice.py

# Backend
FastAPI/migrations
alembic.ini
78 changes: 78 additions & 0 deletions Backend/Airflow/dags/Import_Dataset_SQL.py
@@ -0,0 +1,78 @@
from datetime import datetime, timedelta

from airflow import DAG
from airflow.models.variable import Variable

from airflow.operators.bash import BashOperator
from airflow.operators.python import PythonOperator

import os
import sys
sys.path.append(os.path.abspath('../'))  # make the sibling utils package importable

from utils.databases_import import get_CSV, import_dataset, get_citation_reference
from utils.paper_models_train import model_train_save

default_args = {
'depends_on_past' : True,
'owner' : 'dohyun',
'retries' : 3,
'retry_delay' : timedelta(minutes=5)
}

kaggle_api_key = Variable.get("kaggle_api_key")
kaggle_user = Variable.get("kaggle_username")

api_key = Variable.get("semantic_api_key")  # set in the Airflow UI under Admin > Variables
sql_user = Variable.get("sql_user")
sql_password = Variable.get("sql_password")
sql_port = Variable.get("sql_port")

with DAG(
dag_id = 'Import_Dataset_SQL',
description = 'Import SQL with Target_date',
start_date = datetime(2010, 1, 1),
schedule_interval = '@monthly',
default_args = default_args,
tags = ['my_dags'],
) as dag :

t1 = BashOperator(
task_id = 'Get_Json',
bash_command = f"""
export KAGGLE_USERNAME={kaggle_user}
export KAGGLE_KEY={kaggle_api_key}
kaggle datasets download Cornell-University/arxiv -p /home/dohyun/Final_P/arxiv/ --unzip
""",

)

t2 = PythonOperator(
task_id = 'Get_Csv',
python_callable = get_CSV,
op_args = ['{{execution_date}}'],

)

t3 = PythonOperator(
task_id = 'Get_Citation_Reference',
python_callable = get_citation_reference,
op_args = ['{{execution_date}}', api_key],

)

t4 = PythonOperator(
task_id = 'Import_MySQL',
python_callable = import_dataset,
op_args = ['{{execution_date}}', sql_user, sql_password, sql_port],

)

t5 = PythonOperator(
task_id = 'Model_retraining',
python_callable = model_train_save,
op_args = ['{{execution_date}}', sql_user, sql_password, sql_port],

)

t1 >> t2 >> t3 >> t4 >> t5
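For reference, the DAG reads every credential from Airflow Variables. Below is a minimal one-time setup sketch, assuming you register them programmatically; the keys match the `Variable.get` calls above, and every value is a placeholder, not a real credential.

```python
# One-time setup sketch: register the Variables this DAG reads.
# All values are placeholders; they can equally be set in the
# Airflow UI (Admin > Variables) or with `airflow variables set KEY VALUE`.
from airflow.models.variable import Variable

placeholders = {
    "kaggle_api_key": "<kaggle-api-key>",
    "kaggle_username": "<kaggle-username>",
    "semantic_api_key": "<semantic-scholar-api-key>",
    "sql_user": "<mysql-user>",
    "sql_password": "<mysql-password>",
    "sql_port": "3306",
}

for key, value in placeholders.items():
    Variable.set(key, value)  # persists the variable to the Airflow metadata DB
```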
12 changes: 12 additions & 0 deletions Backend/FastAPI/database.py
@@ -0,0 +1,12 @@
from sqlalchemy import create_engine
from sqlalchemy.ext.declarative import declarative_base
from sqlalchemy.orm import sessionmaker

SQLALCHEMY_DATABASE_URL = 'mysql+pymysql://dohyun:Dhyoon96!@localhost:3306/final_project'

engine = create_engine(
SQLALCHEMY_DATABASE_URL,
)
SessionLocal = sessionmaker(autocommit=False, autoflush=False, bind=engine)

Base = declarative_base()
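The URL above embeds the MySQL credentials in source. A minimal alternative sketch, assuming you would rather read them from environment variables (the `DB_*` names are illustrative, not part of this codebase):

```python
import os

from sqlalchemy import create_engine
from sqlalchemy.orm import sessionmaker

# Hypothetical DB_* environment variables; only the driver and the
# database name are taken from the hardcoded URL above.
SQLALCHEMY_DATABASE_URL = (
    f"mysql+pymysql://{os.environ['DB_USER']}:{os.environ['DB_PASSWORD']}"
    f"@{os.environ.get('DB_HOST', 'localhost')}:{os.environ.get('DB_PORT', '3306')}"
    "/final_project"
)

engine = create_engine(SQLALCHEMY_DATABASE_URL, pool_pre_ping=True)
SessionLocal = sessionmaker(autocommit=False, autoflush=False, bind=engine)
```

`pool_pre_ping=True` is optional but guards against MySQL dropping idle pooled connections.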
100 changes: 100 additions & 0 deletions Backend/FastAPI/domain/chat_data/chat_data.py
@@ -0,0 +1,100 @@
from typing import Annotated
from pydantic import BaseModel, Field
from datetime import datetime
from sqlalchemy.orm import Session
from fastapi import APIRouter, Depends, HTTPException
from starlette import status
from models import Chat, Message, PaperInfo
from database import SessionLocal
from domain.login_data.login_data import get_current_user
from sqlalchemy import asc

router = APIRouter(
prefix='/chat',
tags=['chat']
)

def get_db():
db = SessionLocal()
try:
yield db
finally:
db.close()

db_dependency = Annotated[Session, Depends(get_db)]
user_dependency = Annotated[dict, Depends(get_current_user)]

class ChatRequest(BaseModel):
paper_id : str

class MessageRequest(BaseModel):
content : str
paper_id : str
time : datetime = Field(default_factory=datetime.now)
user_com : bool  # False (0) = user message / True (1) = ChatGPT message

@router.get("/room", status_code=status.HTTP_200_OK)
async def read_all_chat(user: user_dependency, db: db_dependency):
if user is None:
raise HTTPException(status_code=401, detail='Authentication Failed')

return db.query(Chat).filter(Chat.user_id == user.get('id')).all()

@router.get("/", status_code=status.HTTP_200_OK)
async def determine_chat(user: user_dependency, db: db_dependency, paper_id: str):
if user is None:
raise HTTPException(status_code=401, detail='Authentication Failed')

Chat_model = db.query(Chat).filter(Chat.user_id == user.get('id')).filter(Chat.paper_id == paper_id).first()

if Chat_model:
return db.query(Message).filter(Chat_model.chat_id == Message.chat_id).order_by(asc(Message.time)).all()
else:
return False

@router.post("/", status_code=status.HTTP_201_CREATED)
async def create_chat(user: user_dependency, db: db_dependency,
chat_request: ChatRequest):
if user is None:
raise HTTPException(status_code=401, detail='Authentication Failed')

Chat_model = db.query(Chat).filter(Chat.user_id == user.get('id')).filter(Chat.paper_id == chat_request.paper_id).first()

if Chat_model is not None:
raise HTTPException(status_code=409, detail="Chatroom already exists")
else:
        paper_title_model = db.query(PaperInfo).filter(PaperInfo.id == chat_request.paper_id).first()
        if paper_title_model is None:
            raise HTTPException(status_code=404, detail="Paper not found")

        chat_model = Chat(paper_id = chat_request.paper_id, user_id=user.get('id'), paper_title = paper_title_model.title)

db.add(chat_model)
db.commit()

# content, paper_id, user_com (bool): 0 = user message / 1 = chatbot message
@router.post("/message", status_code=status.HTTP_201_CREATED)
async def create_message(user: user_dependency, db: db_dependency,
message_request: MessageRequest):
if user is None:
raise HTTPException(status_code=401, detail='Authentication Failed')

    chat_model = db.query(Chat).filter(Chat.user_id == user.get('id')).filter(Chat.paper_id == message_request.paper_id).first()
    if chat_model is None:
        raise HTTPException(status_code=404, detail="Chatroom not found")

    message_model = Message(content = message_request.content, chat_id = chat_model.chat_id, user_com = message_request.user_com)

db.add(message_model)
db.commit()

@router.delete("/", status_code=status.HTTP_200_OK)
async def delete_chat(user: user_dependency, db: db_dependency,
paper_id: str):
if user is None:
raise HTTPException(status_code=401, detail='Authentication Failed')

Chat_model = db.query(Chat).filter(Chat.user_id == user.get('id')).filter(Chat.paper_id == paper_id).first()
if Chat_model is not None:
db.query(Message).filter(Message.chat_id == Chat_model.chat_id).delete()
db.delete(Chat_model)
db.commit()
else:
        raise HTTPException(status_code=404, detail="Chatroom not found")


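A usage sketch against a running server, assuming the router is mounted as-is and that `get_current_user` (from `login_data`, which is not part of this diff) accepts a standard Bearer token; the paper id and token are placeholders:

```python
import requests

BASE = "http://localhost:8000"
HEADERS = {"Authorization": "Bearer <access-token>"}  # token from the login endpoint
PAPER_ID = "<paper-id>"  # placeholder

# open a chatroom for a paper (409 if it already exists)
requests.post(f"{BASE}/chat/", json={"paper_id": PAPER_ID}, headers=HEADERS)

# store a user turn (user_com=False) in that room
requests.post(
    f"{BASE}/chat/message",
    json={"content": "Summarize this paper.", "paper_id": PAPER_ID, "user_com": False},
    headers=HEADERS,
)

# list this user's chatrooms, then fetch one room's messages in time order
print(requests.get(f"{BASE}/chat/room", headers=HEADERS).json())
print(requests.get(f"{BASE}/chat/", params={"paper_id": PAPER_ID}, headers=HEADERS).json())
```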
70 changes: 70 additions & 0 deletions Backend/FastAPI/domain/get_data/get_data.py
@@ -0,0 +1,70 @@
from fastapi import APIRouter, Depends, HTTPException

from database import SessionLocal
from models import PaperInfo

import sys
sys.path.append("/home/dohyun/Final_P/myapi")

import torch
import torch.nn as nn

from typing import Annotated
from starlette import status
from domain.login_data.login_data import get_current_user

from sqlalchemy import or_

from paper_models_get import get_models_id

########################################################################################################################

device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

class RNN(nn.Module): # Define the RNN model

def __init__(self, input_size, hidden_size, num_layers, output_size):
super(RNN, self).__init__()
self.rnn = nn.RNN(input_size=input_size, hidden_size=hidden_size, num_layers=num_layers, batch_first=True)
self.fc = nn.Linear(hidden_size, output_size)

def forward(self, x):
h0 = torch.zeros(self.rnn.num_layers, x.size(0), self.rnn.hidden_size).to(device) # Move initial hidden state to device
out, _ = self.rnn(x, h0)
out = self.fc(out[:, -1, :])
return out

model_rnn = RNN(input_size=1, hidden_size=128, num_layers=1, output_size=1).to(device) # Move model to device

model_rnn.load_state_dict(torch.load('/home/dohyun/Final_P/paper_models/model_rnn_state_dict.pth', map_location=device))

router = APIRouter(
prefix="/api/data",
)
user_dependency = Annotated[dict, Depends(get_current_user)]

@router.get("/get_data/{user_question}", status_code=status.HTTP_200_OK)
async def get_data(user: user_dependency, user_question: str):
db = SessionLocal()

if user is None:
raise HTTPException(status_code=401, detail='Authentication Failed')

user_question = user_question.strip()

_data_list = db.query(PaperInfo).filter(
or_(
PaperInfo.title.like(f"%{user_question}%"),
PaperInfo.categories.like(f"%{user_question}%"),
)
).all()

sorted_ids = get_models_id(model_rnn, _data_list)
sorted_ids = sorted_ids.values.tolist()

print(sorted_ids) # 10개
matched_papers = db.query(PaperInfo).filter(PaperInfo.id.in_(sorted_ids)).all()

return matched_papers
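A quick shape check for the model above, reusing the `RNN` class and `device` defined in this file: with `batch_first=True` the network expects input of shape `(batch, seq_len, input_size)` and emits one output per sequence, taken from the last time step. The sequence length below is illustrative.

```python
import torch

# Shape sketch, reusing the RNN class and device defined above.
model = RNN(input_size=1, hidden_size=128, num_layers=1, output_size=1).to(device)
x = torch.randn(4, 12, 1).to(device)  # e.g. 4 papers, 12 time steps, 1 feature each
print(model(x).shape)                 # torch.Size([4, 1]) - one score per sequence
```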
