diff --git a/main.py b/main.py index 3aebb074e3..68ad4b9437 100644 --- a/main.py +++ b/main.py @@ -163,7 +163,7 @@ def feature_card_cosine_similarity(card1, card2): return result[0][1].item() -async def fetch_data(user_id, card_type): +async def fetch_data(user_id, card_type, want_to_find): location_cluster = [{'강서구', '양천구'}, {'구로구', '영등포구', '금천구'}, {'동작구', '관악구'}, {'서초구', '강남구'}, {'송파구', '강동구'}, {'은평구', '서대문구', '마포구'}, {'종로구', '중구', '용산구'}, {'중랑구', '동대문구', '성동구', '광진구'}, @@ -187,24 +187,31 @@ async def fetch_data(user_id, card_type): cluster_index = i break - user_cards = [] - post_cards = [] - for location in location_cluster[cluster_index]: - query = f""" - SELECT member_id AS id, location, member_features::jsonb AS features, gender, 'my' AS card_type, birth_year - FROM member_account - JOIN feature_card - ON member_account.my_card_id = feature_card.feature_card_id - WHERE location like '%{location}%' and gender IN ('{user_gender.lower()}', '{user_gender.upper()}') - UNION ALL - SELECT member_id as id, location, member_features::jsonb AS features, gender, 'mate' AS card_type, birth_year - FROM member_account - JOIN feature_card - ON member_account.mate_card_id = feature_card.feature_card_id - WHERE location like '%{location}%' and gender IN ('{user_gender.lower()}', '{user_gender.upper()}') - """ - user_cards.extend([dict(record) for record in await database.fetch_all(query)]) - + cards = [] + + if want_to_find == 'member': + for location in location_cluster[cluster_index]: + if card_type == 'my': + query = f""" + SELECT member_id as id, location, member_features::jsonb AS features, gender, 'mate' AS card_type, birth_year + FROM member_account + JOIN feature_card + ON member_account.mate_card_id = feature_card.feature_card_id + WHERE location like '%{location}%' and gender IN ('{user_gender.lower()}', '{user_gender.upper()}') + """ + cards.extend([dict(record) for record in await database.fetch_all(query)]) + + elif card_type == 'mate': + query = f""" + SELECT member_id AS id, location, member_features::jsonb AS features, gender, 'my' AS card_type, birth_year + FROM member_account + JOIN feature_card + ON member_account.my_card_id = feature_card.feature_card_id + WHERE location like '%{location}%' and gender IN ('{user_gender.lower()}', '{user_gender.upper()}') + """ + cards.extend([dict(record) for record in await database.fetch_all(query)]) + + if want_to_find == 'post': query = f""" SELECT id, location, member_features::jsonb AS features, gender, 'room' AS card_type, member_account.birth_year FROM shared_room_post @@ -212,16 +219,15 @@ async def fetch_data(user_id, card_type): JOIN member_account ON member_account.member_id = shared_room_post.publisher_id WHERE location like '%{location}%' and gender IN ('{user_gender.lower()}', '{user_gender.upper()}') """ - post_cards.extend([dict(record) for record in await database.fetch_all(query)]) + cards.extend([dict(record) for record in await database.fetch_all(query)]) - return user_cards, post_cards, dict(user) + return cards, dict(user) def fill_missing_values(df): imputer = SimpleImputer(strategy='mean') return imputer.fit_transform(df) -async def clustering(user_cards, post_cards, user_card): - cards = [*user_cards, *post_cards] +async def clustering(cards, user_card): if cards == []: cards = [{'id': 'male_default', 'features': None, 'gender': 'MALE', 'card_type': 'my', 'birth_year': '1999'}, @@ -327,13 +333,13 @@ async def root(): @app.get("/recommendation/update") async def update(): start = time.time() - user_cards, post_cards, user_card = ( - await fetch_data("kakao_0", "my") + cards, user_card = ( + await fetch_data("kakao_0", "my", "member") ) print('fetch complete') - await clustering(user_cards, post_cards, user_card) + await clustering(cards, user_card) print("clustering complete") print("time : ", time.time() - start) @@ -341,12 +347,12 @@ async def update(): @app.get("/fetch") async def fetch(): - user_cards, post_cards = ( - await fetch_data("kakao_0", "my") + cards, user_card = ( + await fetch_data("kakao_0", "my", "member") ) - print(user_cards) - print(generate_df_data(post_cards)) + print(cards) + print(generate_df_data(cards)) @app.get("/insert") async def insert():