Skip to content

Commit

Permalink
Merge pull request #99 from boostcampaitech5/develop
Browse files Browse the repository at this point in the history
Develop
  • Loading branch information
seungki1011 authored Aug 2, 2023
2 parents 0fa24ac + f24e836 commit dd36644
Show file tree
Hide file tree
Showing 9 changed files with 152 additions and 44 deletions.
4 changes: 2 additions & 2 deletions Makefile
Original file line number Diff line number Diff line change
Expand Up @@ -5,10 +5,10 @@ nohup_run_server:
nohup uvicorn src.scratch.main:app --host 0.0.0.0 --port 8001 --reload > ./FastAPI_Uvicorn.log 2>&1 &

nohup_run_dreambooth_worker:
nohup celery -A src.scratch.worker_dreambooth.celery_app worker -l info -E > ./celery_worker_dreambooth.log 2>&1 &
nohup celery -A src.scratch.worker_dreambooth.dream_app worker -l info -E -Q dreambooth --concurrency=8 --prefetch-multiplier 1 > ./celery_worker_dreambooth.log 2>&1 &

nohup_run_sdxl_worker:
nohup celery -A src.scratch.worker_sdxl.celery_app worker -l info -E > ./celery_worker_sdxl.log 2>&1 &
nohup celery -A src.scratch.worker_sdxl.celery_app worker -l info -E -Q sdxl --concurrency=8 --prefetch-multiplier 1 > ./celery_worker_sdxl.log 2>&1 &

nohup_run_sd_worker:
nohup celery -A src.scratch.worker_sd.celery_app worker -l info -E > ./celery_worker_sd.log 2>&1 &
1 change: 1 addition & 0 deletions src/scratch/__init__.py
Original file line number Diff line number Diff line change
@@ -1,4 +1,5 @@
from .gpt3_api import get_description, get_translation, get_vibes
from .translation import translate_genre_to_english
from .main import *
from .model import StableDiffusion, StableDiffusionXL
from .streamlit_frontend import main
Expand Down
2 changes: 1 addition & 1 deletion src/scratch/gpt3_api.py
Original file line number Diff line number Diff line change
Expand Up @@ -24,7 +24,7 @@ def get_description(

# message
message = [
f"Tell me 2 to 1 words (objects) that come to mind when you saw \n{lyrics}\n seperated by commas"
f"Tell me 2 objects that come to mind when you see these lyrics \n{lyrics}\n seperated by commas"
]

# Set up the API call
Expand Down
2 changes: 1 addition & 1 deletion src/scratch/htdocs/js/album.js
Original file line number Diff line number Diff line change
Expand Up @@ -253,7 +253,7 @@ document.addEventListener("DOMContentLoaded", () => {
for (let i = 1; i <= 10; i++) {
cur_review = "starpoint_" + i
if (document.getElementById(cur_review).checked == true) {
user_starpoint = Math.round(i / 2);
user_starpoint = (i / 2);
break;
}
}
Expand Down
77 changes: 58 additions & 19 deletions src/scratch/main.py
Original file line number Diff line number Diff line change
Expand Up @@ -41,6 +41,7 @@
gcp_config = load_yaml(os.path.join("src/scratch/config", "private.yaml"), "gcp")
public_config = load_yaml(os.path.join("src/scratch/config", "public.yaml"))
train_config = load_yaml(os.path.join("src/scratch/dreambooth", "dreambooth.yaml"))
redis_config = load_yaml(os.path.join("src/scratch/config", "private.yaml"), "redis")
bigquery_config = gcp_config["bigquery"]


Expand Down Expand Up @@ -68,10 +69,36 @@
# Initialize Celery
celery_app = Celery(
"tasks",
broker="redis://localhost:6379/0",
backend="redis://localhost:6379/0",
broker="redis://kimseungki1011:cv03@localhost:6379/0",
backend="redis://kimseungki1011:cv03@localhost:6379/1",
timezone="Asia/Seoul", # Set the time zone to KST
enable_utc=False,
beat_schedule={
"check_worker_heartbeats": {
"task": "celery.ping",
"schedule": 180, # Check worker heartbeats every 60 seconds
},
},
)

dream_app = Celery(
"tasks_dream",
broker="redis://kimseungki1011:cv03@localhost:6379/0",
backend="redis://kimseungki1011:cv03@localhost:6379/1",
timezone="Asia/Seoul", # Set the time zone to KST
enable_utc=False,
beat_schedule={
"check_worker_heartbeats": {
"task": "celery.ping",
"schedule": 180, # Check worker heartbeats every 60 seconds
},
},
)


# Set Celery Time-zone
celery_app.conf.timezone = "Asia/Seoul"


def get_random_string(length):
# choose from all lowercase letter
Expand Down Expand Up @@ -284,37 +311,49 @@ async def upload_image(image: UploadFile = File(...)):
image_content = base64.b64encode(image_bytes).decode()

# Use asyncio.gather to run the task asynchronously
task = celery_app.send_task(
task = dream_app.send_task(
"save_image", args=[image.filename, image_content, token]
)
asyncio.create_task(
wait_for_task_completion(task)
) # Run the task in the background
# asyncio.create_task(
# wait_for_task_completion(task)
# ) # Run the task in the background
try:
# Wait for the task to complete with a timeout of 60 seconds
result = task.get(timeout=60)
# Process the result if needed
print(result)
except TimeoutError:
# Task took too long to complete
print("Task timed out.")
return {"status": "Task timed out. Please check the status later."}
except Exception as e:
# Handle any other exceptions
print("Error occurred:", e)

return {"status": "File upload started"}


@api_router.post("/train_inference")
async def train(input: UserAlbumInput):
# Use asyncio.gather to run the task asynchronously
task = celery_app.send_task(
task = dream_app.send_task(
"train_inference", args=[input.dict(), token, request_id]
)

return {"task_id": task.id}


# Helper function to wait for task completion
async def wait_for_task_completion(task):
try:
result = task.get()
# Process the result if needed
print(result)
except asyncio.TimeoutError:
# Task took too long to complete
print("Task timed out.")
except Exception as e:
# Handle any other exceptions
print("Error occurred:", e)
# async def wait_for_task_completion(task):
# try:
# result = task.get()
# # Process the result if needed
# print(result)
# except asyncio.TimeoutError:
# # Task took too long to complete
# print("Task timed out.")
# except Exception as e:
# # Handle any other exceptions
# print("Error occurred:", e)


app.include_router(api_router)
Expand Down
18 changes: 18 additions & 0 deletions src/scratch/translation.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,18 @@
# Genre Translation


def translate_genre_to_english(genre):
# Define the translation dictionary
genre_translation = {
"발라드": "ballad",
"댄스": "dance",
"트로트": "trot",
"랩/힙합": "rap&hiphop",
"인디음악": "indie-music",
"록/메탈": "rock&metal",
"포크/블루스": "folk&blues"
# Add more genre translations here if needed
}

# Use the get() method to handle cases where the genre is not in the dictionary
return genre_translation.get(genre, genre)
22 changes: 14 additions & 8 deletions src/scratch/worker_dreambooth.py
Original file line number Diff line number Diff line change
Expand Up @@ -48,17 +48,23 @@
login(token=huggingface_config["token"], add_to_git_credential=True)

# Initialize Celery
celery_app = Celery(
"tasks",
broker=redis_config["redis_server_ip"],
backend=redis_config["redis_server_ip"],
dream_app = Celery(
"tasks_dream",
broker="redis://kimseungki1011:[email protected]:6379/0",
backend="redis://kimseungki1011:[email protected]:6379/1",
timezone="Asia/Seoul", # Set the time zone to KST
enable_utc=False,
worker_heartbeat=280,
)
celery_app.conf.worker_pool = "solo"
dream_app.conf.worker_pool = "solo"

# Set Celery Time-zone
dream_app.conf.timezone = "Asia/Seoul"

device = "cuda" if cuda.is_available() else "cpu"


@celery_app.task(name="save_image")
@dream_app.task(name="save_image", queue="dreambooth")
def save_image(filename, image_content, token):
# Define the directory where to save the image
image_dir = Path("src/scratch/dreambooth/data/users") / token
Expand All @@ -75,7 +81,7 @@ def save_image(filename, image_content, token):
return {"image_url": str(image_dir / filename)}


@celery_app.task(name="train_inference")
@dream_app.task(name="train_inference", queue="dreambooth")
def train_inference(input, token, request_id):
try:
global model
Expand Down Expand Up @@ -198,4 +204,4 @@ def train_inference(input, token, request_id):


if __name__ == "__main__":
celery_app.worker_main(["-l", "info"])
dream_app.worker_main(["-l", "info"])
11 changes: 8 additions & 3 deletions src/scratch/worker_sd.py
Original file line number Diff line number Diff line change
Expand Up @@ -34,15 +34,20 @@
public_config = load_yaml(os.path.join("src/scratch/config", "public.yaml"))


# Initialize Celery
celery_app = Celery(
"tasks",
broker=redis_config["redis_server_ip"],
backend=redis_config["redis_server_ip"],
broker="redis://kimseungki1011:cv03@localhost:6379/0",
backend="redis://kimseungki1011:cv03@localhost:6379/1",
timezone="Asia/Seoul", # Set the time zone to KST
enable_utc=False,
worker_heartbeat=280, # Set the heartbeat timeout in seconds (e.g., 180 seconds)
)

celery_app.conf.worker_pool = "solo"

# Set Celery Time-zone
celery_app.conf.timezone = "Asia/Seoul"

gcs_uploader = GCSUploader(gcp_config)
bigquery_logger = BigQueryLogger(gcp_config)

Expand Down
59 changes: 49 additions & 10 deletions src/scratch/worker_sdxl.py
Original file line number Diff line number Diff line change
Expand Up @@ -27,6 +27,7 @@
from .gcp.cloud_storage import GCSUploader
from .gcp.bigquery import BigQueryLogger
from .utils import load_yaml
from .translation import translate_genre_to_english


# Load config
Expand All @@ -39,12 +40,18 @@
# Initialize Celery
celery_app = Celery(
"tasks",
broker=redis_config["redis_server_ip"],
backend=redis_config["redis_server_ip"],
broker="redis://kimseungki1011:[email protected]:6379/0",
backend="redis://kimseungki1011:[email protected]:6379/1",
timezone="Asia/Seoul", # Set the time zone to KST
enable_utc=False,
worker_heartbeat=280,
)

celery_app.conf.worker_pool = "solo"

# Set Celery Time-zone
celery_app.conf.timezone = "Asia/Seoul"

gcs_uploader = GCSUploader(gcp_config)
bigquery_logger = BigQueryLogger(gcp_config)

Expand All @@ -57,7 +64,7 @@ def setup_worker_init(*args, **kwargs):
model.get_model()


@celery_app.task(name="generate_cover")
@celery_app.task(name="generate_cover", queue="sdxl")
def generate_cover(input, request_id):
os.environ["CUDA_LAUNCH_BLOCKING"] = "1"
device = "cuda" if cuda.is_available() else "cpu"
Expand Down Expand Up @@ -85,20 +92,24 @@ def generate_cover(input, request_id):
input["song_name"],
input["genre"],
)
prompt = f"A photo of a {input['genre']} album cover with a {vibe} atmosphere visualized and {summarization} on it"
genre = translate_genre_to_english(input["genre"])

prompt = f"Korean music album photo of a {get_translation(input['artist_name'])} who sang {get_translation(input['song_name'])}, full body, on {summarization}, Bounced lighting, dutch angle, Aaton LTR"
new_prompt = f"A photo or picture of a {genre} album cover that has a {vibe} vibe visualzied and has {summarization} on it"
else:
prompt = f"A photo of a {input['genre']} album cover with a {vibe} atmosphere visualized and {summarization} on it"
prompt = f"Korean music album photo of a {get_translation(input['artist_name'])} who sang {get_translation(input['song_name'])}, full body, Bounced lighting, dutch angle, Aaton LTR"
new_prompt = f"A photo or picture of a {genre} album cover that has a {vibe} vibe visualzied and has {summarization} on it"

prompt = re.sub("\n", ", ", prompt)
prompt = re.sub("[ㄱ-ㅎ가-힣]+", " ", prompt)
prompt = re.sub("[()-]", " ", prompt)
prompt = re.sub("\s+", " ", prompt)

if len(prompt) <= 150:
if len(prompt) <= 200:
break

if model is None:
time.sleep(20)
# if model is None:
# time.sleep(20)

seeds = np.random.randint(
public_config["generate"]["max_seed"], size=public_config["generate"]["n_gen"]
Expand All @@ -111,10 +122,10 @@ def generate_cover(input, request_id):
with torch.no_grad():
image = model.pipeline(
prompt=prompt,
prompt_2=negative_prompt,
prompt_2=prompt,
height=public_config["generate"]["height"],
width=public_config["generate"]["width"],
num_inference_steps=20,
num_inference_steps=100,
generator=generator,
).images[0]

Expand All @@ -132,6 +143,34 @@ def generate_cover(input, request_id):
)
images.append(byte_arr)

for i, seed in enumerate(seeds):
generator = torch.Generator(device=device).manual_seed(int(seed))

# Generate Images
with torch.no_grad():
image = model.pipeline(
prompt=new_prompt,
prompt_2=new_prompt,
height=public_config["generate"]["height"],
width=public_config["generate"]["width"],
num_inference_steps=100,
generator=generator,
).images[0]

# Convert to base64-encoded string
byte_arr = io.BytesIO()
image.save(byte_arr, format=public_config["generate"]["save_format"])
byte_arr = byte_arr.getvalue()

# Upload to GCS
urls.append(
[
byte_arr,
f"{request_id}_image_{i+2}.{public_config['generate']['save_format']}",
]
)
images.append(byte_arr)

# Upload to GCS
image_urls = gcs_uploader.save_image_to_gcs(urls)

Expand Down

0 comments on commit dd36644

Please sign in to comment.