diff --git a/backend/app/frontend.py b/backend/app/frontend.py index 9d6a3cc..0077c2a 100644 --- a/backend/app/frontend.py +++ b/backend/app/frontend.py @@ -7,16 +7,58 @@ import streamlit as st from app.confirm_button_hack import cache_on_button_press +import base64 # SETTING PAGE CONFIG TO WIDE MODE -ASSETS_DIR_PATH = os.path.join(Path(__file__).parent.parent.parent.parent, "assets") - +ASSETS_DIR_PATH = os.path.join(Path(__file__).parent.parent.parent.parent, "assets") st.set_page_config(layout="wide") -root_password = 'a' +# st.set_page_config(layout="wide") + +root_password = 'a' category_pair = {'Upper':'upper_body', 'Lower':'lower_body', 'Upper & Lower':'upper_lower', 'Dress':'dresses'} +db_dir = '/opt/ml/user_db' + +def apply_custom_font(text, font_size=48): + # 글꼴 로드 + current_dir = os.path.dirname(os.path.abspath(__file__)) + parent_dir = os.path.dirname(current_dir) + font_filename = 'NanumSquareB.ttf' + font_path = os.path.join(parent_dir, "assets", font_filename) + + try: + with open(font_path, "rb") as f: + font_data = f.read() + font_base64 = base64.b64encode(font_data).decode("utf-8") + font_style = f"font-family: 'CustomFont' ; font-size: {font_size}px;" + styled_text = f'
{text}
' + return f'{styled_text}' + except Exception as e: + st.error(f"Error loading the font: {e}") + return None + +def user_guideline_for_human(): + + st.write(' ') + text1 = """
1. 전신 사진을 넣어주세요.
""" + text3 = """
2. 입을 옷을 선택해주세요.
""" + text4 = """
3. '옷 입히기 시작' 버튼을 눌러주세요.
""" + + st.markdown(text1, unsafe_allow_html=True) + st.markdown(text3, unsafe_allow_html=True) + st.markdown(text4, unsafe_allow_html=True) + +def user_guideline_for_garment(): + + st.write(' ') + text1 = """
1. 상의, 하의, 상의&하의, 드레스 카테고리를 선택해주세요.
""" #text-align: center; + text2 = """
2. 단일 옷 사진을 넣어주세요.
""" + + st.markdown(text1, unsafe_allow_html=True) + st.markdown(text2, unsafe_allow_html=True) + def check_modelLoading(): api_url = "http://localhost:8001/get_boolean" is_modelLoading = True @@ -29,89 +71,231 @@ def check_modelLoading(): pass return is_modelLoading +def read_image_as_bytes(image_path): + with open(image_path, "rb") as file: + image_data = file.read() + return image_data +## 이미지 리스트에 저장 +def append_imgList(uploaded_garment, category): + + garment_bytes = uploaded_garment.getvalue() + file = [ + ('files', category), + ('files', (uploaded_garment.name, garment_bytes, + uploaded_garment.type))] + response = requests.post("http://localhost:8001/add_data", files=file) + response.raise_for_status() ## 200이 아니면 예외처리 + +## 저장된 이미지 리스트들을 체크박스와 함께 띄우는 함수 +def show_garments_and_checkboxes(category): + + category_dir = os.path.join(db_dir, 'input/garment', category) + filenames = os.listdir(category_dir) + + num_columns = 3 + num_rows = (len(filenames) - 1) // num_columns + 1 + + # 이미지들을 오른쪽으로 정렬하여 표시하기 위해 컬럼 생성 + cols = st.columns(num_columns) + for i, filename in enumerate(filenames): + im_dir = os.path.join(category_dir, filename) + garment_img = Image.open(im_dir) + garment_byte = read_image_as_bytes(im_dir) + # st.image(garment_img, caption=filename[:-4], width=100) + cols[i % num_columns].image(garment_img, width=100, use_column_width=True, caption=filename[:-4]) + # if st.checkbox(filename[:-4]) : + # return True, garment_byte + # else : + # return False, None + filenames_ = [None] + filenames_.extend([f[:-4] for f in filenames]) + selected_garment = st.selectbox('입을 옷을 선택해주세요.', filenames_, index=0) + print('selected_garment', selected_garment) + + return filenames, selected_garment + +def md_style(): + st.markdown( + """ + + """, + unsafe_allow_html=True + ) + + st.markdown( + """ + + """, + unsafe_allow_html=True + ) + st.markdown( + """ + + """, + unsafe_allow_html=True + ) + st.markdown( + """ + + """, + unsafe_allow_html=True + ) + def main(): - st.title("Welcome to VTON World :)") - is_all_uploaded = False + md_style() + # st.title("d") #🌳나만의 드레스룸🌳 + st.markdown("

🌳나만의 드레스룸🌳

", unsafe_allow_html=True) + with st.container(): col1, col2, col3 = st.columns([1,1,1]) - + files = [0, 0, 0, ('files', 0)] + is_selected_upper = False + is_selected_lower = False + is_selected_dress = False + gen_start = False + with col1: - st.header("Human") + st.markdown("

상의👚

", unsafe_allow_html=True) + # user_guideline_for_garment() + category = 'upper_body' + + uploaded_garment = st.file_uploader("추가할 상의를 넣어주세요.", type=["jpg", "jpeg", "png"]) + + if uploaded_garment : + append_imgList(uploaded_garment, category) + + filenames, selected_upper = show_garments_and_checkboxes(category) + if selected_upper : + is_selected_upper = True + files[2] = ('files', f'{selected_upper}.jpg') + print('selected_upper', selected_upper) + + with col3: - # target_img = Image.open('/opt/ml/user_db/input/buffer/target/target.jpg') - uploaded_target = st.file_uploader("Choose an target image", type=["jpg", "jpeg", "png"]) + st.markdown("

하의👖

", unsafe_allow_html=True) + category = 'lower_body' - if uploaded_target: - target_bytes = uploaded_target.getvalue() - target_img = Image.open(io.BytesIO(target_bytes)) + uploaded_garment = st.file_uploader("추가할 하의를 넣어주세요.", type=["jpg", "jpeg", "png"]) + + if uploaded_garment : + append_imgList(uploaded_garment, category) + + filenames, selected_lower = show_garments_and_checkboxes(category) + if selected_lower : + is_selected_lower = True + files[3] = ('files', f'{selected_lower}.jpg') + + st.write(' ') + st.write(' ') + st.markdown("

드레스👗

", unsafe_allow_html=True) + category = 'dresses' + + uploaded_garment = st.file_uploader("추가할 드레스를 넣어주세요.", type=["jpg", "jpeg", "png"]) + + if uploaded_garment : + append_imgList(uploaded_garment, category) + + filenames, selected_dress = show_garments_and_checkboxes(category) + if selected_dress : + is_selected_dress = True + files[2] = ('files', f'{selected_dress}.jpg') + print('is_selected_lower', is_selected_lower) + print('is_selected_dress', is_selected_dress) + - st.image(target_img, caption='Uploaded target Image') - with col2: - st.header("Cloth") + st.markdown("

드레스룸🚪

", unsafe_allow_html=True) - category_list = ['Upper', 'Lower', 'Upper & Lower', 'Dress'] - selected_category = st.selectbox('Choose an category of garment', category_list) - # uploaded_garment = Image.open('/opt/ml/user_db/input/buffer/garment/garment.jpg') + uploaded_target = st.file_uploader("전신 사진을 넣어주세요.", type=["jpg", "jpeg", "png"]) + user_guideline_for_human() - category = category_pair[selected_category] - print('**category:', category) + # start_button = st.markdown("
", unsafe_allow_html=True) + start_button = st.button("옷 입히기 시작", use_container_width=True) + + human_slot = st.empty() + if uploaded_target: + target_bytes = uploaded_target.getvalue() + target_img = Image.open(io.BytesIO(target_bytes)) - if selected_category == 'Upper & Lower': - uploaded_garment1 = st.file_uploader("Choose an upper image", type=["jpg", "jpeg", "png"]) - uploaded_garment2 = st.file_uploader("Choose an lower image", type=["jpg", "jpeg", "png"]) + human_slot.empty() + human_slot.image(target_img) - col2_1, col2_2, = st.columns([1,1]) - with col2_1: - if uploaded_garment1: - garment_bytes1 = uploaded_garment1.getvalue() - garment_img1 = Image.open(io.BytesIO(garment_bytes1)) - st.image(garment_img1, caption='Uploaded upper Image') + # else : - with col2_2: - if uploaded_garment2: - garment_bytes2 = uploaded_garment2.getvalue() - garment_img2 = Image.open(io.BytesIO(garment_bytes2)) - st.image(garment_img2, caption='Uploaded lower Image') - - if uploaded_target and uploaded_garment1 and uploaded_garment2: - is_all_uploaded = True - files = [ - ('files', category), - ('files', (uploaded_target.name, target_bytes, - uploaded_target.type)), - ('files', (uploaded_garment1.name, garment_bytes1, - uploaded_garment1.type)), - ('files', (uploaded_garment2.name, garment_bytes2, - uploaded_garment2.type)) - ] - - - else : - uploaded_garment = st.file_uploader("Choose an garment image", type=["jpg", "jpeg", "png"]) - - if uploaded_garment: - garment_bytes = uploaded_garment.getvalue() - garment_img = Image.open(io.BytesIO(garment_bytes)) - st.image(garment_img, caption='Uploaded garment Image') - - if uploaded_target and uploaded_garment : - is_all_uploaded = True - files = [ - ('files', category), - ('files', (uploaded_target.name, target_bytes, - uploaded_target.type)), - ('files', (uploaded_garment.name, garment_bytes, - uploaded_garment.type)), - ] + # example_img = Image.open('/opt/ml/level3_cv_finalproject-cv-12/backend/app/utils/example.jpg') + # human_slot.image(example_img, width=300, use_column_width=True, caption='Example of target image') - with col3: - st.header("Result") + print('start_button', start_button) + if start_button and uploaded_target: + if is_selected_upper and is_selected_lower : + gen_start = True + category = 'upper_lower' + + elif is_selected_upper : + print('catogory upperrr') + gen_start = True + category = 'upper_body' + elif is_selected_lower : + gen_start = True + category = 'lower_body' + files[2] = files[3] ## lower가 3index에 저장됨, upper&lower가 아닐 경우엔 2로 저장 + elif is_selected_dress : + gen_start = True + category = 'dresses' + else : + gen_start = False + files[0] = ('files', category) + files[1] = ('files', (uploaded_target.name, target_bytes, + uploaded_target.type)) + print('category', category) + print('files2', files[2]) + print('files3', files[3]) - if is_all_uploaded: - - with col3: + if gen_start : + st.write(' ') empty_slot = st.empty() empty_slot.markdown("

\nLoading...

", unsafe_allow_html=True) @@ -141,15 +325,13 @@ def main(): st.write(' ') st.write(' ') st.write(' ') - st.image(final_img, caption='Final Image', use_column_width=True) - # option = '선택 안 함' - # down_btn = st.download_button( - # label='Download Image', - # data=dehaze_image_bytes, - # file_name='dehazed_image.jpg', - # mime='image/jpg', - # on_click=save_btn_click(option, dehaze_image_bytes) - # ) + human_slot.empty() + human_slot.image(final_img, caption='Final Image', use_column_width=True) + + + is_selected_upper = False + is_selected_lower = False + is_selected_dress = False main() \ No newline at end of file diff --git a/backend/app/main.py b/backend/app/main.py index b2ad673..ad8eba6 100644 --- a/backend/app/main.py +++ b/backend/app/main.py @@ -4,7 +4,6 @@ from uuid import UUID, uuid4 from typing import List, Union, Optional, Dict, Any from datetime import datetime -from app.model import MyEfficientNet, get_model, get_config, predict_from_image_byte from PIL import Image import io @@ -42,6 +41,39 @@ app = FastAPI() ladi_models = None +db_dir = '/opt/ml/user_db' + +@app.post("/add_data", description="데이터 저장") +async def add_garment_to_db(files: List[UploadFile] = File(...)): + byte_string = await files[0].read() ##await + string_io = io.BytesIO(byte_string) + category = string_io.read().decode('utf-8') + + print('category in main', category) + + garment_bytes = await files[1].read() ##await + garment_name = files[1].filename + print('!!!! garment_name', garment_name) + garment_image = Image.open(io.BytesIO(garment_bytes)) + garment_image = garment_image.convert("RGB") + + garment_image.save(os.path.join(db_dir, 'input/garment', category, f'{garment_name}')) + +def read_image_as_bytes(image_path): + with open(image_path, "rb") as file: + image_data = file.read() + return image_data + +@app.get("/get_db/{category}") +async def get_DB(category: str) : + category_dir = os.path.join(db_dir, 'input/garment', category) + garment_db_bytes = {} + for filename in os.listdir(category_dir): + garment_id = filename[:-4] + garment_byte = read_image_as_bytes(os.path.join(category_dir, filename)) + garment_db_bytes[garment_id] = garment_byte + return garment_db_bytes + def load_ladiModels(): pretrained_model_name_or_path = "stabilityai/stable-diffusion-2-inpainting" @@ -182,12 +214,8 @@ def inference_ladi(category, db_dir, target_name='target.jpg'): # post!! @app.post("/order", description="주문을 요청합니다") -async def make_order( - files: List[UploadFile] = File(...), - model: MyEfficientNet = Depends(get_model), - config: Dict[str, Any] = Depends(get_config)): +async def make_order(files: List[UploadFile] = File(...)): - db_dir = '/opt/ml/user_db' input_dir = '/opt/ml/user_db/input/' # category : files[0], target:files[1], garment:files[2] @@ -203,23 +231,37 @@ async def make_order( os.makedirs(f'{input_dir}/buffer', exist_ok=True) - target_image.save(f'{input_dir}/target.jpg') + # target_image.save(f'{input_dir}/target.jpg') target_image.save(f'{input_dir}/buffer/target/target.jpg') if category == 'upper_lower': - garment_upper_bytes = await files[2].read() - garment_lower_bytes = await files[3].read() + # garment_upper_bytes = await files[2].read() + # garment_lower_bytes = await files[3].read() - garment_upper_image = Image.open(io.BytesIO(garment_upper_bytes)) - garment_upper_image = garment_upper_image.convert("RGB") - garment_lower_image = Image.open(io.BytesIO(garment_lower_bytes)) - garment_lower_image = garment_lower_image.convert("RGB") - + # garment_upper_image = Image.open(io.BytesIO(garment_upper_bytes)) + # garment_upper_image = garment_upper_image.convert("RGB") + # garment_lower_image = Image.open(io.BytesIO(garment_lower_bytes)) + # garment_lower_image = garment_lower_image.convert("RGB") + + # # garment_upper_image.save(f'{input_dir}/upper_body.jpg') + # garment_upper_image.save(f'{input_dir}/buffer/garment/upper_body.jpg') + # # garment_lower_image.save(f'{input_dir}/lower_body.jpg') + # garment_lower_image.save(f'{input_dir}/buffer/garment/lower_body.jpg') + + + ## string으로 전송됐을 때(filename) + string_upper_bytes = await files[2].read() + string_lower_bytes = await files[3].read() + string_io_upper = io.BytesIO(string_upper_bytes) + string_io_lower = io.BytesIO(string_lower_bytes) + filename_upper = string_io_upper.read().decode('utf-8') + filename_lower = string_io_lower.read().decode('utf-8') + + garment_image_upper = Image.open(os.path.join(db_dir, 'input/garment', 'upper_body', filename_upper)) + garment_image_lower = Image.open(os.path.join(db_dir, 'input/garment', 'lower_body', filename_lower)) + garment_image_upper.save(f'{input_dir}/buffer/garment/upper_body.jpg') + garment_image_lower.save(f'{input_dir}/buffer/garment/lower_body.jpg') - garment_upper_image.save(f'{input_dir}/upper_body.jpg') - garment_upper_image.save(f'{input_dir}/buffer/garment/upper_body.jpg') - garment_lower_image.save(f'{input_dir}/lower_body.jpg') - garment_lower_image.save(f'{input_dir}/buffer/garment/lower_body.jpg') inference_allModels('upper_body', db_dir) shutil.copy(os.path.join(db_dir, 'ladi/buffer', 'upper_body.png'), f'{input_dir}/buffer/target/upper_body.jpg') @@ -227,27 +269,28 @@ async def make_order( else : - garment_bytes = await files[2].read() - - garment_image = Image.open(io.BytesIO(garment_bytes)) - garment_image = garment_image.convert("RGB") - - garment_image.save(f'{input_dir}/{category}.jpg') + ## file로 전송됐을 때 + # garment_bytes = await files[2].read() + # garment_image = Image.open(io.BytesIO(garment_bytes)) + # garment_image = garment_image.convert("RGB") + # garment_image.save(f'{input_dir}/buffer/garment/{category}.jpg') + + ## string으로 전송됐을 때(filename) + byte_string = await files[2].read() + string_io = io.BytesIO(byte_string) + filename = string_io.read().decode('utf-8') + + garment_image = Image.open(os.path.join(db_dir, 'input/garment', category, filename)) garment_image.save(f'{input_dir}/buffer/garment/{category}.jpg') + inference_allModels(category, db_dir) return None ## return값 ## output dir - inference_result = predict_from_image_byte(model=model, image_bytes=image_bytes, config=config) - product = InferenceImageProduct(result=inference_result) - products.append(product) - new_order = Order(products=products) - orders.append(new_order) - return new_order def update_order_by_id(order_id: UUID, order_update: OrderUpdate) -> Optional[Order]: diff --git a/backend/app/model.py b/backend/app/model.py index 5a8c6a9..7c879ef 100644 --- a/backend/app/model.py +++ b/backend/app/model.py @@ -9,7 +9,6 @@ import torch.nn as nn import torch.nn.functional as F from PIL import Image -from efficientnet_pytorch import EfficientNet class MyEfficientNet(nn.Module): diff --git a/backend/assets/NanumSquareB.ttf b/backend/assets/NanumSquareB.ttf new file mode 100644 index 0000000..7711cac Binary files /dev/null and b/backend/assets/NanumSquareB.ttf differ