diff --git a/backend/app/frontend.py b/backend/app/frontend.py
index 9d6a3cc..0077c2a 100644
--- a/backend/app/frontend.py
+++ b/backend/app/frontend.py
@@ -7,16 +7,58 @@
import streamlit as st
from app.confirm_button_hack import cache_on_button_press
+import base64
# SETTING PAGE CONFIG TO WIDE MODE
-ASSETS_DIR_PATH = os.path.join(Path(__file__).parent.parent.parent.parent, "assets")
-
+ASSETS_DIR_PATH = os.path.join(Path(__file__).parent.parent.parent.parent, "assets")
st.set_page_config(layout="wide")
-root_password = 'a'
+# st.set_page_config(layout="wide")
+
+root_password = 'a'
category_pair = {'Upper':'upper_body', 'Lower':'lower_body', 'Upper & Lower':'upper_lower', 'Dress':'dresses'}
+db_dir = '/opt/ml/user_db'
+
+def apply_custom_font(text, font_size=48):
+ # 글꼴 로드
+ current_dir = os.path.dirname(os.path.abspath(__file__))
+ parent_dir = os.path.dirname(current_dir)
+ font_filename = 'NanumSquareB.ttf'
+ font_path = os.path.join(parent_dir, "assets", font_filename)
+
+ try:
+ with open(font_path, "rb") as f:
+ font_data = f.read()
+ font_base64 = base64.b64encode(font_data).decode("utf-8")
+ font_style = f"font-family: 'CustomFont' ; font-size: {font_size}px;"
+ styled_text = f'
{text}
'
+ return f'{styled_text}'
+ except Exception as e:
+ st.error(f"Error loading the font: {e}")
+ return None
+
+def user_guideline_for_human():
+
+ st.write(' ')
+ text1 = """ 1. 전신 사진을 넣어주세요.
"""
+ text3 = """ 2. 입을 옷을 선택해주세요.
"""
+ text4 = """ 3. '옷 입히기 시작' 버튼을 눌러주세요.
"""
+
+ st.markdown(text1, unsafe_allow_html=True)
+ st.markdown(text3, unsafe_allow_html=True)
+ st.markdown(text4, unsafe_allow_html=True)
+
+def user_guideline_for_garment():
+
+ st.write(' ')
+ text1 = """ 1. 상의, 하의, 상의&하의, 드레스 카테고리를 선택해주세요.
""" #text-align: center;
+ text2 = """ 2. 단일 옷 사진을 넣어주세요.
"""
+
+ st.markdown(text1, unsafe_allow_html=True)
+ st.markdown(text2, unsafe_allow_html=True)
+
def check_modelLoading():
api_url = "http://localhost:8001/get_boolean"
is_modelLoading = True
@@ -29,89 +71,231 @@ def check_modelLoading():
pass
return is_modelLoading
+def read_image_as_bytes(image_path):
+ with open(image_path, "rb") as file:
+ image_data = file.read()
+ return image_data
+## 이미지 리스트에 저장
+def append_imgList(uploaded_garment, category):
+
+ garment_bytes = uploaded_garment.getvalue()
+ file = [
+ ('files', category),
+ ('files', (uploaded_garment.name, garment_bytes,
+ uploaded_garment.type))]
+ response = requests.post("http://localhost:8001/add_data", files=file)
+ response.raise_for_status() ## 200이 아니면 예외처리
+
+## 저장된 이미지 리스트들을 체크박스와 함께 띄우는 함수
+def show_garments_and_checkboxes(category):
+
+ category_dir = os.path.join(db_dir, 'input/garment', category)
+ filenames = os.listdir(category_dir)
+
+ num_columns = 3
+ num_rows = (len(filenames) - 1) // num_columns + 1
+
+ # 이미지들을 오른쪽으로 정렬하여 표시하기 위해 컬럼 생성
+ cols = st.columns(num_columns)
+ for i, filename in enumerate(filenames):
+ im_dir = os.path.join(category_dir, filename)
+ garment_img = Image.open(im_dir)
+ garment_byte = read_image_as_bytes(im_dir)
+ # st.image(garment_img, caption=filename[:-4], width=100)
+ cols[i % num_columns].image(garment_img, width=100, use_column_width=True, caption=filename[:-4])
+ # if st.checkbox(filename[:-4]) :
+ # return True, garment_byte
+ # else :
+ # return False, None
+ filenames_ = [None]
+ filenames_.extend([f[:-4] for f in filenames])
+ selected_garment = st.selectbox('입을 옷을 선택해주세요.', filenames_, index=0)
+ print('selected_garment', selected_garment)
+
+ return filenames, selected_garment
+
+def md_style():
+ st.markdown(
+ """
+
+ """,
+ unsafe_allow_html=True
+ )
+
+ st.markdown(
+ """
+
+ """,
+ unsafe_allow_html=True
+ )
+ st.markdown(
+ """
+
+ """,
+ unsafe_allow_html=True
+ )
+ st.markdown(
+ """
+
+ """,
+ unsafe_allow_html=True
+ )
+
def main():
- st.title("Welcome to VTON World :)")
- is_all_uploaded = False
+ md_style()
+ # st.title("d") #🌳나만의 드레스룸🌳
+ st.markdown("🌳나만의 드레스룸🌳
", unsafe_allow_html=True)
+
with st.container():
col1, col2, col3 = st.columns([1,1,1])
-
+ files = [0, 0, 0, ('files', 0)]
+ is_selected_upper = False
+ is_selected_lower = False
+ is_selected_dress = False
+ gen_start = False
+
with col1:
- st.header("Human")
+ st.markdown("", unsafe_allow_html=True)
+ # user_guideline_for_garment()
+ category = 'upper_body'
+
+ uploaded_garment = st.file_uploader("추가할 상의를 넣어주세요.", type=["jpg", "jpeg", "png"])
+
+ if uploaded_garment :
+ append_imgList(uploaded_garment, category)
+
+ filenames, selected_upper = show_garments_and_checkboxes(category)
+ if selected_upper :
+ is_selected_upper = True
+ files[2] = ('files', f'{selected_upper}.jpg')
+ print('selected_upper', selected_upper)
+
+ with col3:
- # target_img = Image.open('/opt/ml/user_db/input/buffer/target/target.jpg')
- uploaded_target = st.file_uploader("Choose an target image", type=["jpg", "jpeg", "png"])
+ st.markdown("", unsafe_allow_html=True)
+ category = 'lower_body'
- if uploaded_target:
- target_bytes = uploaded_target.getvalue()
- target_img = Image.open(io.BytesIO(target_bytes))
+ uploaded_garment = st.file_uploader("추가할 하의를 넣어주세요.", type=["jpg", "jpeg", "png"])
+
+ if uploaded_garment :
+ append_imgList(uploaded_garment, category)
+
+ filenames, selected_lower = show_garments_and_checkboxes(category)
+ if selected_lower :
+ is_selected_lower = True
+ files[3] = ('files', f'{selected_lower}.jpg')
+
+ st.write(' ')
+ st.write(' ')
+ st.markdown("", unsafe_allow_html=True)
+ category = 'dresses'
+
+ uploaded_garment = st.file_uploader("추가할 드레스를 넣어주세요.", type=["jpg", "jpeg", "png"])
+
+ if uploaded_garment :
+ append_imgList(uploaded_garment, category)
+
+ filenames, selected_dress = show_garments_and_checkboxes(category)
+ if selected_dress :
+ is_selected_dress = True
+ files[2] = ('files', f'{selected_dress}.jpg')
+ print('is_selected_lower', is_selected_lower)
+ print('is_selected_dress', is_selected_dress)
+
- st.image(target_img, caption='Uploaded target Image')
-
with col2:
- st.header("Cloth")
+ st.markdown("", unsafe_allow_html=True)
- category_list = ['Upper', 'Lower', 'Upper & Lower', 'Dress']
- selected_category = st.selectbox('Choose an category of garment', category_list)
- # uploaded_garment = Image.open('/opt/ml/user_db/input/buffer/garment/garment.jpg')
+ uploaded_target = st.file_uploader("전신 사진을 넣어주세요.", type=["jpg", "jpeg", "png"])
+ user_guideline_for_human()
- category = category_pair[selected_category]
- print('**category:', category)
+ # start_button = st.markdown("", unsafe_allow_html=True)
+ start_button = st.button("옷 입히기 시작", use_container_width=True)
+
+ human_slot = st.empty()
+ if uploaded_target:
+ target_bytes = uploaded_target.getvalue()
+ target_img = Image.open(io.BytesIO(target_bytes))
- if selected_category == 'Upper & Lower':
- uploaded_garment1 = st.file_uploader("Choose an upper image", type=["jpg", "jpeg", "png"])
- uploaded_garment2 = st.file_uploader("Choose an lower image", type=["jpg", "jpeg", "png"])
+ human_slot.empty()
+ human_slot.image(target_img)
- col2_1, col2_2, = st.columns([1,1])
- with col2_1:
- if uploaded_garment1:
- garment_bytes1 = uploaded_garment1.getvalue()
- garment_img1 = Image.open(io.BytesIO(garment_bytes1))
- st.image(garment_img1, caption='Uploaded upper Image')
+ # else :
- with col2_2:
- if uploaded_garment2:
- garment_bytes2 = uploaded_garment2.getvalue()
- garment_img2 = Image.open(io.BytesIO(garment_bytes2))
- st.image(garment_img2, caption='Uploaded lower Image')
-
- if uploaded_target and uploaded_garment1 and uploaded_garment2:
- is_all_uploaded = True
- files = [
- ('files', category),
- ('files', (uploaded_target.name, target_bytes,
- uploaded_target.type)),
- ('files', (uploaded_garment1.name, garment_bytes1,
- uploaded_garment1.type)),
- ('files', (uploaded_garment2.name, garment_bytes2,
- uploaded_garment2.type))
- ]
-
-
- else :
- uploaded_garment = st.file_uploader("Choose an garment image", type=["jpg", "jpeg", "png"])
-
- if uploaded_garment:
- garment_bytes = uploaded_garment.getvalue()
- garment_img = Image.open(io.BytesIO(garment_bytes))
- st.image(garment_img, caption='Uploaded garment Image')
-
- if uploaded_target and uploaded_garment :
- is_all_uploaded = True
- files = [
- ('files', category),
- ('files', (uploaded_target.name, target_bytes,
- uploaded_target.type)),
- ('files', (uploaded_garment.name, garment_bytes,
- uploaded_garment.type)),
- ]
+ # example_img = Image.open('/opt/ml/level3_cv_finalproject-cv-12/backend/app/utils/example.jpg')
+ # human_slot.image(example_img, width=300, use_column_width=True, caption='Example of target image')
- with col3:
- st.header("Result")
+ print('start_button', start_button)
+ if start_button and uploaded_target:
+ if is_selected_upper and is_selected_lower :
+ gen_start = True
+ category = 'upper_lower'
+
+ elif is_selected_upper :
+ print('catogory upperrr')
+ gen_start = True
+ category = 'upper_body'
+ elif is_selected_lower :
+ gen_start = True
+ category = 'lower_body'
+ files[2] = files[3] ## lower가 3index에 저장됨, upper&lower가 아닐 경우엔 2로 저장
+ elif is_selected_dress :
+ gen_start = True
+ category = 'dresses'
+ else :
+ gen_start = False
+ files[0] = ('files', category)
+ files[1] = ('files', (uploaded_target.name, target_bytes,
+ uploaded_target.type))
+ print('category', category)
+ print('files2', files[2])
+ print('files3', files[3])
- if is_all_uploaded:
-
- with col3:
+ if gen_start :
+
st.write(' ')
empty_slot = st.empty()
empty_slot.markdown("\nLoading...
", unsafe_allow_html=True)
@@ -141,15 +325,13 @@ def main():
st.write(' ')
st.write(' ')
st.write(' ')
- st.image(final_img, caption='Final Image', use_column_width=True)
- # option = '선택 안 함'
- # down_btn = st.download_button(
- # label='Download Image',
- # data=dehaze_image_bytes,
- # file_name='dehazed_image.jpg',
- # mime='image/jpg',
- # on_click=save_btn_click(option, dehaze_image_bytes)
- # )
+ human_slot.empty()
+ human_slot.image(final_img, caption='Final Image', use_column_width=True)
+
+
+ is_selected_upper = False
+ is_selected_lower = False
+ is_selected_dress = False
main()
\ No newline at end of file
diff --git a/backend/app/main.py b/backend/app/main.py
index b2ad673..ad8eba6 100644
--- a/backend/app/main.py
+++ b/backend/app/main.py
@@ -4,7 +4,6 @@
from uuid import UUID, uuid4
from typing import List, Union, Optional, Dict, Any
from datetime import datetime
-from app.model import MyEfficientNet, get_model, get_config, predict_from_image_byte
from PIL import Image
import io
@@ -42,6 +41,39 @@
app = FastAPI()
ladi_models = None
+db_dir = '/opt/ml/user_db'
+
+@app.post("/add_data", description="데이터 저장")
+async def add_garment_to_db(files: List[UploadFile] = File(...)):
+ byte_string = await files[0].read() ##await
+ string_io = io.BytesIO(byte_string)
+ category = string_io.read().decode('utf-8')
+
+ print('category in main', category)
+
+ garment_bytes = await files[1].read() ##await
+ garment_name = files[1].filename
+ print('!!!! garment_name', garment_name)
+ garment_image = Image.open(io.BytesIO(garment_bytes))
+ garment_image = garment_image.convert("RGB")
+
+ garment_image.save(os.path.join(db_dir, 'input/garment', category, f'{garment_name}'))
+
+def read_image_as_bytes(image_path):
+ with open(image_path, "rb") as file:
+ image_data = file.read()
+ return image_data
+
+@app.get("/get_db/{category}")
+async def get_DB(category: str) :
+ category_dir = os.path.join(db_dir, 'input/garment', category)
+ garment_db_bytes = {}
+ for filename in os.listdir(category_dir):
+ garment_id = filename[:-4]
+ garment_byte = read_image_as_bytes(os.path.join(category_dir, filename))
+ garment_db_bytes[garment_id] = garment_byte
+ return garment_db_bytes
+
def load_ladiModels():
pretrained_model_name_or_path = "stabilityai/stable-diffusion-2-inpainting"
@@ -182,12 +214,8 @@ def inference_ladi(category, db_dir, target_name='target.jpg'):
# post!!
@app.post("/order", description="주문을 요청합니다")
-async def make_order(
- files: List[UploadFile] = File(...),
- model: MyEfficientNet = Depends(get_model),
- config: Dict[str, Any] = Depends(get_config)):
+async def make_order(files: List[UploadFile] = File(...)):
- db_dir = '/opt/ml/user_db'
input_dir = '/opt/ml/user_db/input/'
# category : files[0], target:files[1], garment:files[2]
@@ -203,23 +231,37 @@ async def make_order(
os.makedirs(f'{input_dir}/buffer', exist_ok=True)
- target_image.save(f'{input_dir}/target.jpg')
+ # target_image.save(f'{input_dir}/target.jpg')
target_image.save(f'{input_dir}/buffer/target/target.jpg')
if category == 'upper_lower':
- garment_upper_bytes = await files[2].read()
- garment_lower_bytes = await files[3].read()
+ # garment_upper_bytes = await files[2].read()
+ # garment_lower_bytes = await files[3].read()
- garment_upper_image = Image.open(io.BytesIO(garment_upper_bytes))
- garment_upper_image = garment_upper_image.convert("RGB")
- garment_lower_image = Image.open(io.BytesIO(garment_lower_bytes))
- garment_lower_image = garment_lower_image.convert("RGB")
-
+ # garment_upper_image = Image.open(io.BytesIO(garment_upper_bytes))
+ # garment_upper_image = garment_upper_image.convert("RGB")
+ # garment_lower_image = Image.open(io.BytesIO(garment_lower_bytes))
+ # garment_lower_image = garment_lower_image.convert("RGB")
+
+ # # garment_upper_image.save(f'{input_dir}/upper_body.jpg')
+ # garment_upper_image.save(f'{input_dir}/buffer/garment/upper_body.jpg')
+ # # garment_lower_image.save(f'{input_dir}/lower_body.jpg')
+ # garment_lower_image.save(f'{input_dir}/buffer/garment/lower_body.jpg')
+
+
+ ## string으로 전송됐을 때(filename)
+ string_upper_bytes = await files[2].read()
+ string_lower_bytes = await files[3].read()
+ string_io_upper = io.BytesIO(string_upper_bytes)
+ string_io_lower = io.BytesIO(string_lower_bytes)
+ filename_upper = string_io_upper.read().decode('utf-8')
+ filename_lower = string_io_lower.read().decode('utf-8')
+
+ garment_image_upper = Image.open(os.path.join(db_dir, 'input/garment', 'upper_body', filename_upper))
+ garment_image_lower = Image.open(os.path.join(db_dir, 'input/garment', 'lower_body', filename_lower))
+ garment_image_upper.save(f'{input_dir}/buffer/garment/upper_body.jpg')
+ garment_image_lower.save(f'{input_dir}/buffer/garment/lower_body.jpg')
- garment_upper_image.save(f'{input_dir}/upper_body.jpg')
- garment_upper_image.save(f'{input_dir}/buffer/garment/upper_body.jpg')
- garment_lower_image.save(f'{input_dir}/lower_body.jpg')
- garment_lower_image.save(f'{input_dir}/buffer/garment/lower_body.jpg')
inference_allModels('upper_body', db_dir)
shutil.copy(os.path.join(db_dir, 'ladi/buffer', 'upper_body.png'), f'{input_dir}/buffer/target/upper_body.jpg')
@@ -227,27 +269,28 @@ async def make_order(
else :
- garment_bytes = await files[2].read()
-
- garment_image = Image.open(io.BytesIO(garment_bytes))
- garment_image = garment_image.convert("RGB")
-
- garment_image.save(f'{input_dir}/{category}.jpg')
+ ## file로 전송됐을 때
+ # garment_bytes = await files[2].read()
+ # garment_image = Image.open(io.BytesIO(garment_bytes))
+ # garment_image = garment_image.convert("RGB")
+ # garment_image.save(f'{input_dir}/buffer/garment/{category}.jpg')
+
+ ## string으로 전송됐을 때(filename)
+ byte_string = await files[2].read()
+ string_io = io.BytesIO(byte_string)
+ filename = string_io.read().decode('utf-8')
+
+ garment_image = Image.open(os.path.join(db_dir, 'input/garment', category, filename))
garment_image.save(f'{input_dir}/buffer/garment/{category}.jpg')
+
inference_allModels(category, db_dir)
return None
## return값
## output dir
- inference_result = predict_from_image_byte(model=model, image_bytes=image_bytes, config=config)
- product = InferenceImageProduct(result=inference_result)
- products.append(product)
- new_order = Order(products=products)
- orders.append(new_order)
- return new_order
def update_order_by_id(order_id: UUID, order_update: OrderUpdate) -> Optional[Order]:
diff --git a/backend/app/model.py b/backend/app/model.py
index 5a8c6a9..7c879ef 100644
--- a/backend/app/model.py
+++ b/backend/app/model.py
@@ -9,7 +9,6 @@
import torch.nn as nn
import torch.nn.functional as F
from PIL import Image
-from efficientnet_pytorch import EfficientNet
class MyEfficientNet(nn.Module):
diff --git a/backend/assets/NanumSquareB.ttf b/backend/assets/NanumSquareB.ttf
new file mode 100644
index 0000000..7711cac
Binary files /dev/null and b/backend/assets/NanumSquareB.ttf differ