Skip to content

Commit

Permalink
Merge pull request #9 from boostcampaitech5/feat/backend_loadmodel
Browse files Browse the repository at this point in the history
[Feat] frontend/backend 초안
  • Loading branch information
Hyunmin-H authored Jul 17, 2023
2 parents f3d6c5c + 475983c commit d0b9b9b
Show file tree
Hide file tree
Showing 588 changed files with 272 additions and 164 deletions.
1 change: 1 addition & 0 deletions .gitignore
Original file line number Diff line number Diff line change
@@ -0,0 +1 @@
backend/app/__pycache__/
Binary file removed backend/app/__pycache__/__main__.cpython-310.pyc
Binary file not shown.
Binary file removed backend/app/__pycache__/__main__.cpython-38.pyc
Binary file not shown.
Binary file not shown.
Binary file not shown.
Binary file removed backend/app/__pycache__/main.cpython-310.pyc
Binary file not shown.
Binary file removed backend/app/__pycache__/main.cpython-38.pyc
Binary file not shown.
Binary file removed backend/app/__pycache__/model.cpython-310.pyc
Binary file not shown.
Binary file removed backend/app/__pycache__/model.cpython-38.pyc
Binary file not shown.
4 changes: 2 additions & 2 deletions backend/app/confirm_button_hack.py
Original file line number Diff line number Diff line change
Expand Up @@ -51,9 +51,9 @@ def evaluate(self):
if st.button(label):
cache_entry.evaluate()
else:
raise st.script_runner.StopException
st.stop()
return cache_entry.return_value

return wrapped_func

return function_decorator
return function_decorator
99 changes: 73 additions & 26 deletions backend/app/frontend.py
Original file line number Diff line number Diff line change
Expand Up @@ -17,37 +17,84 @@


def main():
st.title("Mask Classification Model")
uploaded_file = st.file_uploader("Choose an image", type=["jpg", "jpeg", "png"])
st.title("Welcome to VTON World :)")

with st.container():
col1, col2, col3 = st.columns([1,1,1])

with col1:
st.header("Human")
uploaded_target = st.file_uploader("Choose an target image", type=["jpg", "jpeg", "png"])

if uploaded_target:
target_bytes = uploaded_target.getvalue()
target_img = Image.open(io.BytesIO(target_bytes))

if uploaded_file:
image_bytes = uploaded_file.getvalue()
image = Image.open(io.BytesIO(image_bytes))
st.image(target_img, caption='Uploaded target Image')

with col2:
st.header("Cloth")
uploaded_garment = st.file_uploader("Choose an garment image", type=["jpg", "jpeg", "png"])

st.image(image, caption='Uploaded Image')
st.write("Classifying...")
if uploaded_garment:
# st.spinner("dehazing now...")

garment_bytes = uploaded_garment.getvalue()
garment_img = Image.open(io.BytesIO(garment_bytes))

# 기존 stremalit 코드
# _, y_hat = get_prediction(model, image_bytes)
# label = config['classes'][y_hat.item()]
files = [
('files', (uploaded_file.name, image_bytes,
uploaded_file.type))
]
response = requests.post("http://localhost:8001/order", files=files)
label = response.json()["products"][0]["result"]
st.write(f'label is {label}')
st.image(garment_img, caption='Uploaded garment Image')

with col3:
st.header("Result")
if uploaded_target and uploaded_garment:
files = [
('files', (uploaded_target.name, target_bytes,
uploaded_target.type))
,
('files', (uploaded_garment.name, garment_bytes,
uploaded_garment.type))
]

with col3:
st.write(' ')
empty_slot = st.empty()
empty_slot.markdown("<h2 style='text-align: center;'>\nLoading...</h2>", unsafe_allow_html=True)

response = requests.post("http://localhost:8001/order", files=files)
empty_slot.empty()
empty_slot.markdown("<h2 style='text-align: center;'>Here it is !</h2>", unsafe_allow_html=True)

@cache_on_button_press('Authenticate')
def authenticate(password) -> bool:
return password == root_password
category = 'lower_body'
output_ladi_buffer_dir = '/opt/ml/user_db/ladi/buffer'
final_result_dir = output_ladi_buffer_dir
final_img = Image.open(os.path.join(final_result_dir, f'{category}.png'))

st.write(' ')
st.write(' ')
st.write(' ')
st.write(' ')
st.write(' ')
st.write(' ')
st.image(final_img, caption='Final Image', use_column_width=True)

# option = '선택 안 함'
# down_btn = st.download_button(
# label='Download Image',
# data=dehaze_image_bytes,
# file_name='dehazed_image.jpg',
# mime='image/jpg',
# on_click=save_btn_click(option, dehaze_image_bytes)
# )

# @cache_on_button_press('Authenticate')
# def authenticate(password) -> bool:
# return password == root_password
# password = st.text_input('password', type="password")
# if authenticate(password):
# st.success('You are authenticated!')
# main()
# else:
# st.error('The password is invalid.')

password = st.text_input('password', type="password")

if authenticate(password):
st.success('You are authenticated!')
main()
else:
st.error('The password is invalid.')
main()
75 changes: 70 additions & 5 deletions backend/app/main.py
Original file line number Diff line number Diff line change
Expand Up @@ -7,7 +7,30 @@
from datetime import datetime

from app.model import MyEfficientNet, get_model, get_config, predict_from_image_byte
from PIL import Image
import io

# scp setting
import sys, os
sys.path.append('/opt/ml/level3_cv_finalproject-cv-12/model/Self_Correction_Human_Parsing/')
from simple_extractor import main_schp

# openpose
sys.path.append('/opt/ml/level3_cv_finalproject-cv-12/model/pytorch_openpose/')
from extract_keypoint import main_openpose

# ladi
sys.path.append('/opt/ml/level3_cv_finalproject-cv-12/model/ladi_vton')
sys.path.append('/opt/ml/level3_cv_finalproject-cv-12/model/ladi_vton/src')
from inference import main_ladi



# sys.path.append('/opt/ml/level3_cv_finalproject-cv-12/model')
# print('sys.path:', sys.path)
# from Self_Correction_Human_Parsing.simple_extractor import main_schp
# from pytorch_openpose.extract_keypoint import main_openpose
# from ladi_vton.src.inference import main_ladi
app = FastAPI()

orders = []
Expand Down Expand Up @@ -71,17 +94,59 @@ def get_order_by_id(order_id: UUID) -> Optional[Order]:
return next((order for order in orders if order.id == order_id), None)


# post!!
@app.post("/order", description="주문을 요청합니다")
async def make_order(files: List[UploadFile] = File(...),
# def make_order(files: List[UploadFile] = File(...),
model: MyEfficientNet = Depends(get_model),
config: Dict[str, Any] = Depends(get_config)):
products = []
for file in files:
image_bytes = await file.read()
inference_result = predict_from_image_byte(model=model, image_bytes=image_bytes, config=config)
product = InferenceImageProduct(result=inference_result)
products.append(product)

# target:files[0], garment:files[1]

target_bytes = await files[0].read()
garment_bytes = await files[1].read()

# TODO image byte
target_image = Image.open(io.BytesIO(target_bytes))
target_image = target_image.convert("RGB")

garment_image = Image.open(io.BytesIO(garment_bytes))
garment_image = garment_image.convert("RGB")

input_dir = '/opt/ml/user_db/input/'

os.makedirs(f'{input_dir}/buffer', exist_ok=True)

target_image.save(f'{input_dir}/target.jpg')
target_image.save(f'{input_dir}/buffer/target/target.jpg')

garment_image.save(f'{input_dir}/garment.jpg')
garment_image.save(f'{input_dir}/buffer/garment/garment.jpg')

# schp - (1024, 784), (512, 384)
target_buffer_dir = f'{input_dir}/buffer/target'
main_schp(target_buffer_dir)


# openpose
output_openpose_buffer_dir = '/opt/ml/user_db/openpose/buffer'
os.makedirs(output_openpose_buffer_dir, exist_ok=True)
main_openpose(target_buffer_dir, output_openpose_buffer_dir)

# ladi-vton
output_ladi_buffer_dir = '/opt/ml/user_db/ladi/buffer'
db_dir = '/opt/ml/user_db'
os.makedirs(output_ladi_buffer_dir, exist_ok=True)
main_ladi(db_dir, output_ladi_buffer_dir)

return None
## return값
## output dir

inference_result = predict_from_image_byte(model=model, image_bytes=image_bytes, config=config)
product = InferenceImageProduct(result=inference_result)
products.append(product)

new_order = Order(products=products)
orders.append(new_order)
Expand Down
File renamed without changes.
Loading

0 comments on commit d0b9b9b

Please sign in to comment.