add multinode support using ray #17

Open

wants to merge 1 commit into base: main
216 changes: 149 additions & 67 deletions closed/NVIDIA/code/stable-diffusion-xl/tensorrt/backend.py
@@ -18,7 +18,8 @@
import array
import os
import time

import ray
import traceback
import numpy as np
import tensorrt as trt
import torch
@@ -401,7 +402,7 @@ def make_infer_await_h2d(self, infer_stream):
def await_infer_done(self, infer_done):
CUASSERT(cudart.cudaStreamWaitEvent(self.stream, infer_done, 0))


@ray.remote(num_gpus=1, runtime_env={"env_vars": {"PYTHONPATH": "/work"}})
class SDXLCore:
def __init__(self,
device_id: int,
@@ -412,7 +413,7 @@ def __init__(self,
use_graphs: bool = False,
verbose: bool = False,
verbose_nvtx: bool = False):

device_id = 0
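# editorial note: with num_gpus=1, Ray exposes only the GPU assigned to this actor, so within the actor it appears as device 0, which is presumably why the incoming device_id is overridden above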
CUASSERT(cudart.cudaSetDevice(device_id))
torch.autograd.set_grad_enabled(False)
self.device = "cuda"
@@ -424,7 +425,7 @@ def __init__(self,
self.verbose_nvtx = verbose_nvtx

self._verbose_info(f"[Device {self.device_id}] Initializing")

logging.info("initialized")
# NVTX components
if self.verbose_nvtx:
self.nvtx_markers = {}
@@ -452,8 +453,8 @@ def __init__(self,
self.copy_stream = SDXLCopyStream(device_id, gpu_batch_size)

# QSR components
self.response_queue = queue.Queue()
self.response_thread = threading.Thread(target=self._process_response, args=(), daemon=True)
# self.response_queue = queue.Queue()
# self.response_thread = threading.Thread(target=self._process_response, args=(), daemon=True)
# self.start_inference = threading.Condition()

# Initialize scheduler
@@ -487,13 +488,15 @@ def __init__(self,
self.engines['unet'].enable_cuda_graphs(self.buffers)

# Initialize QSR thread
self.response_thread.start()

# self.response_thread.start()
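# editorial note: the QSR queue/thread are disabled because images are now returned from the Ray actor and reported to LoadGen by the server's report_complete thread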
def get_total_samples(self):
return self.total_samples
def __del__(self):
pass
# exit all threads
self.response_queue.put(None)
self.response_queue.join()
self.response_thread.join()
# self.response_queue.put(None)
# self.response_queue.join()
# self.response_thread.join()

def _verbose_info(self, msg):
if self.verbose:
@@ -624,44 +627,77 @@ def _save_buffer_to_images(self):
nvtx_profile_stop("post_process", self.nvtx_markers)

def generate_images(self, samples):
CUASSERT(cudart.cudaSetDevice(self.device_id))
if self.verbose_nvtx:
nvtx_profile_start("read_tokens", self.nvtx_markers, color='yellow')
actual_batch_size = len(samples)
sample_indices = [q.index for q in samples]
sample_ids = [q.id for q in samples]
self._verbose_info(f"[Device {self.device_id}] Running inference on sample {sample_indices} with batch size {actual_batch_size}")

# TODO add copy stream support
prompt_tokens_clip1 = self.dataset.prompt_tokens_clip1[sample_indices, :].to(self.device)
prompt_tokens_clip2 = self.dataset.prompt_tokens_clip2[sample_indices, :].to(self.device)
negative_prompt_tokens_clip1 = self.dataset.negative_prompt_tokens_clip1[sample_indices, :].to(self.device)
negative_prompt_tokens_clip2 = self.dataset.negative_prompt_tokens_clip2[sample_indices, :].to(self.device)
logging.info("generate_images")
# print(samples)
# print(samples[0].id)
try:
CUASSERT(cudart.cudaSetDevice(self.device_id))
if self.verbose_nvtx:
nvtx_profile_start("read_tokens", self.nvtx_markers, color='yellow')
actual_batch_size = len(samples)
sample_indices = [q.index for q in samples]
sample_ids = [q.id for q in samples]
self._verbose_info(f"[Device {self.device_id}] Running inference on sample {sample_indices} with batch size {actual_batch_size}")

# TODO add copy stream support
prompt_tokens_clip1 = self.dataset.prompt_tokens_clip1[sample_indices, :].to(self.device)
prompt_tokens_clip2 = self.dataset.prompt_tokens_clip2[sample_indices, :].to(self.device)
negative_prompt_tokens_clip1 = self.dataset.negative_prompt_tokens_clip1[sample_indices, :].to(self.device)
negative_prompt_tokens_clip2 = self.dataset.negative_prompt_tokens_clip2[sample_indices, :].to(self.device)

if self.verbose_nvtx:
nvtx_profile_stop("read_tokens", self.nvtx_markers)
nvtx_profile_start("stage_clip_buffers", self.nvtx_markers, color='pink')
self._transfer_to_clip_buffer(
prompt_tokens_clip1,
prompt_tokens_clip2,
negative_prompt_tokens_clip1,
negative_prompt_tokens_clip2
)
if self.verbose_nvtx:
nvtx_profile_stop("stage_clip_buffers", self.nvtx_markers)
# nvtx_profile_start("generate_images", self.nvtx_markers)
if self.verbose_nvtx:
nvtx_profile_stop("read_tokens", self.nvtx_markers)
nvtx_profile_start("stage_clip_buffers", self.nvtx_markers, color='pink')
self._transfer_to_clip_buffer(
prompt_tokens_clip1,
prompt_tokens_clip2,
negative_prompt_tokens_clip1,
negative_prompt_tokens_clip2
)
if self.verbose_nvtx:
nvtx_profile_stop("stage_clip_buffers", self.nvtx_markers)
# nvtx_profile_start("generate_images", self.nvtx_markers)

self._encode_tokens(actual_batch_size)
self._denoise_latent(actual_batch_size) # runs self.denoising_steps inside
self._decode_latent(actual_batch_size)
self._encode_tokens(actual_batch_size)
self._denoise_latent(actual_batch_size) # runs self.denoising_steps inside
self._decode_latent(actual_batch_size)

self._save_buffer_to_images()
self._save_buffer_to_images()

# Report back to loadgen use sample_ids
# response = SDXLResponse(sample_ids=sample_ids,
# generated_images=self.copy_stream.vae_outputs,
# results_ready=self.copy_stream.d2h_event)
# self.response_queue.put(response)

qsr = []
actual_batch_size = len(samples)
self.total_samples += actual_batch_size
return self.copy_stream.vae_outputs
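# NOTE (editorial): everything below this return is unreachable and appears to be left over from the pre-Ray reporting path (it references an undefined `response`)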

CUASSERT(cudart.cudaEventSynchronize(response.results_ready))
self._verbose_info(f"[Device {self.device_id}] Reporting back {actual_batch_size} samples")

if self.verbose_nvtx:
nvtx_profile_start("report_qsl", self.nvtx_markers, color='yellow')

for idx, sample_id in enumerate(response.sample_ids):
qsr.append(lg.QuerySampleResponse(sample_id,
response.generated_images[idx].data_ptr(),
response.generated_images[idx].nelement() * response.generated_images[idx].element_size()))

# breakpoint()
lg.QuerySamplesComplete(qsr)

if self.verbose_nvtx:
nvtx_profile_stop("report_qsl", self.nvtx_markers)


except Exception as e:
tb = traceback.format_exc()
logging.info(f"Actor died due to error: {e}\n{tb}")
raise

# Report back to loadgen use sample_ids
response = SDXLResponse(sample_ids=sample_ids,
generated_images=self.copy_stream.vae_outputs,
results_ready=self.copy_stream.d2h_event)
self.response_queue.put(response)

def warm_up(self, warm_up_iters):
CUASSERT(cudart.cudaSetDevice(self.device_id))
@@ -700,8 +736,9 @@ def __init__(self,
verbose_nvtx: bool = False,
enable_batcher: bool = False,
batch_timeout_threashold: float = -1):

self.devices = devices
ray.init()
self.num_gpus = int(ray.available_resources()["GPU"])
self.devices = range(self.num_gpus)
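# editorial note: the device list now spans every GPU Ray reports across the cluster, replacing the `devices` constructor argument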
self.gpu_batch_size = gpu_batch_size
self.verbose = verbose
self.verbose_nvtx = verbose_nvtx
@@ -716,10 +753,13 @@ def __init__(self,
self.sample_count = 0
self.sdxl_cores = {}
self.core_threads = []

self.future = []
self.result_queue = queue.Queue()

# Initialize the cores
for device_id in self.devices:
self.sdxl_cores[device_id] = SDXLCore(device_id=device_id,
self.sdxl_cores[device_id] = SDXLCore.remote(device_id=device_id,
dataset=dataset,
gpu_engine_files=gpu_engine_files,
gpu_batch_size=self.gpu_batch_size,
@@ -728,12 +768,9 @@ def __init__(self,
verbose=self.verbose,
verbose_nvtx=self.verbose_nvtx)

# Start the cores
for device_id in self.devices:
thread = threading.Thread(target=self.process_samples, args=(device_id,))
# thread.daemon = True
self.core_threads.append(thread)
thread.start()

self.report_thread = threading.Thread(target=self.report_complete)
self.report_thread.start()

if self.enable_batcher:
self.batcher_threshold = batch_timeout_threashold # maximum seconds to form a batch
@@ -749,7 +786,8 @@ def _verbose_info(self, msg):

def warm_up(self):
for device_id in self.devices:
self.sdxl_cores[device_id].warm_up(warm_up_iters=2)
future = self.sdxl_cores[device_id].warm_up.remote(warm_up_iters=2)
ray.get(future)
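
Editorial note (not part of the diff): warm_up above submits and waits on each actor in turn. A hedged sketch of overlapping warm-up across all devices, assuming warm_up keeps this signature:

futures = [self.sdxl_cores[d].warm_up.remote(warm_up_iters=2) for d in self.devices]
ray.get(futures)  # wait for every device's warm-up together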

def process_samples(self, device_id):
while True:
@@ -758,8 +796,30 @@ def process_samples(self, device_id):
# None in the queue indicates the SUT want us to exit
self.sample_queue.task_done()
break
self.sdxl_cores[device_id].generate_images(samples)
future = self.sdxl_cores[device_id].generate_images.remote(samples)
# output = ray.get(future)
self.sample_queue.task_done()
self.result_queue.put((samples, future))
def report_complete(self):
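# Drain (samples, future) pairs queued by issue_queries/process_samples, block on ray.get for the generated images, and report them to LoadGen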
while True:

res = self.result_queue.get()
if res is None:
self.result_queue.task_done()
break
self.result_queue.task_done()
samples = res[0]
future = res[1]
output = ray.get(future)
qsr = []
for idx in range(len(samples)):
sample_id = samples[idx].id
generated_images = output
qsr.append(lg.QuerySampleResponse(sample_id,
generated_images[idx].data_ptr(),
generated_images[idx].nelement() * generated_images[idx].element_size()))
# breakpoint()
lg.QuerySamplesComplete(qsr)

def batch_samples(self):
batched_samples = self.batcher_queue.get()
@@ -784,13 +844,32 @@ def issue_queries(self, query_samples):
num_samples = len(query_samples)
self._verbose_info(f"[Server] Received {num_samples} samples")
self.sample_count += num_samples
for i in range(0, num_samples, self.gpu_batch_size):
# Construct batches
actual_batch_size = self.gpu_batch_size if num_samples - i > self.gpu_batch_size else num_samples - i
if self.enable_batcher:
self.batcher_queue.put(query_samples[i: i + actual_batch_size])
else:
self.sample_queue.put(query_samples[i: i + actual_batch_size])
logging.info(f"{num_samples}")

samples = []
total_batch = self.gpu_batch_size * self.num_gpus
t = num_samples // total_batch
left = num_samples % total_batch
idx = 0
for i in range(t):
for device_id in self.devices:
sample = query_samples[idx: idx + self.gpu_batch_size]
idx += self.gpu_batch_size
future = self.sdxl_cores[device_id].generate_images.remote(sample)
self.result_queue.put((sample, future))

batch_size = left // self.num_gpus
left = left % self.num_gpus
for i in self.devices:
len_ = batch_size
if i < left:
len_ += 1
sample = query_samples[idx: idx + len_]
idx += len_
future = self.sdxl_cores[i].generate_images.remote(sample)
self.result_queue.put((sample, future))



def flush_queries(self):
pass
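
Editorial note on the dispatch arithmetic in issue_queries above: full rounds of gpu_batch_size * num_gpus samples are dealt out first, then the remainder is spread one extra sample at a time over the lowest-numbered devices. A small illustrative sketch of the resulting per-device counts (names are hypothetical, not from the PR):

def split_counts(num_samples, gpu_batch_size, num_gpus):
    # per-device sample counts produced by the scheme above
    total_batch = gpu_batch_size * num_gpus
    full_rounds = num_samples // total_batch       # rounds where every GPU gets a full batch
    base, extra = divmod(num_samples % total_batch, num_gpus)
    return [full_rounds * gpu_batch_size + base + (1 if d < extra else 0)
            for d in range(num_gpus)]

# e.g. 50 samples, gpu_batch_size=8, 4 GPUs -> [13, 13, 12, 12]
print(split_counts(50, 8, 4))
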
@@ -799,16 +878,19 @@ def finish_test(self):
# exit all threads
self._verbose_info(f"SUT finished!")
logging.info(f"[Server] Received {self.sample_count} total samples")
for _ in self.core_threads:
self.sample_queue.put(None)
self.sample_queue.join()
# for _ in self.core_threads:
# self.sample_queue.put(None)
self.result_queue.put(None)
self.result_queue.join()
# self.sample_queue.join()
if self.enable_batcher:
self.batcher_queue.put(None)
self.batcher_thread.join()
for device_id in self.devices:
logging.info(f"[Device {device_id}] Reported {self.sdxl_cores[device_id].total_samples} samples")
logging.info(f"[Device {device_id}] Reported {ray.get(self.sdxl_cores[device_id].get_total_samples.remote())} samples")
for thread in self.core_threads:
thread.join()
self.report_thread.join()


if __name__ == '__main__':