diff --git a/closed/NVIDIA/code/stable-diffusion-xl/tensorrt/backend.py b/closed/NVIDIA/code/stable-diffusion-xl/tensorrt/backend.py
index a36d4e00e..007f7bdae 100644
--- a/closed/NVIDIA/code/stable-diffusion-xl/tensorrt/backend.py
+++ b/closed/NVIDIA/code/stable-diffusion-xl/tensorrt/backend.py
@@ -18,7 +18,8 @@
 import array
 import os
 import time
-
+import ray
+import traceback
 import numpy as np
 import tensorrt as trt
 import torch
@@ -401,7 +402,7 @@ def make_infer_await_h2d(self, infer_stream):
 
     def await_infer_done(self, infer_done):
         CUASSERT(cudart.cudaStreamWaitEvent(self.stream, infer_done, 0))
-
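+# Each SDXLCore now runs as a Ray actor pinned to a single GPU: Ray sets
+# CUDA_VISIBLE_DEVICES for the actor process, so the assigned GPU is always
+# seen as device ordinal 0 inside the actor (which is what the device_id
+# override in __init__ relies on), and PYTHONPATH is forwarded so the actor
+# can import the /work code tree.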
+@ray.remote(num_gpus=1, runtime_env={"env_vars": {"PYTHONPATH": "/work"}})
 class SDXLCore:
     def __init__(self,
                  device_id: int,
@@ -412,7 +413,7 @@ def __init__(self,
                  use_graphs: bool = False,
                  verbose: bool = False,
                  verbose_nvtx: bool = False):
-
+        device_id = 0
         CUASSERT(cudart.cudaSetDevice(device_id))
         torch.autograd.set_grad_enabled(False)
         self.device = "cuda"
@@ -424,7 +425,7 @@ def __init__(self,
         self.verbose_nvtx = verbose_nvtx
 
         self._verbose_info(f"[Device {self.device_id}] Initializing")
-
+        logging.info("initialized")
         # NVTX components
         if self.verbose_nvtx:
             self.nvtx_markers = {}
@@ -452,8 +453,8 @@ def __init__(self,
         self.copy_stream = SDXLCopyStream(device_id, gpu_batch_size)
 
         # QSR components
-        self.response_queue = queue.Queue()
-        self.response_thread = threading.Thread(target=self._process_response, args=(), daemon=True)
+        # self.response_queue = queue.Queue()
+        # self.response_thread = threading.Thread(target=self._process_response, args=(), daemon=True)
         # self.start_inference = threading.Condition()
 
         # Initialize scheduler
@@ -487,13 +488,15 @@ def __init__(self,
             self.engines['unet'].enable_cuda_graphs(self.buffers)
 
         # Initialize QSR thread
-        self.response_thread.start()
-
+        # self.response_thread.start()
+    def get_total_samples(self):
+        return self.total_samples
     def __del__(self):
+        pass
         # exit all threads
-        self.response_queue.put(None)
-        self.response_queue.join()
-        self.response_thread.join()
+        # self.response_queue.put(None)
+        # self.response_queue.join()
+        # self.response_thread.join()
 
     def _verbose_info(self, msg):
         if self.verbose:
@@ -624,44 +627,77 @@ def _save_buffer_to_images(self):
         nvtx_profile_stop("post_process", self.nvtx_markers)
 
     def generate_images(self, samples):
-        CUASSERT(cudart.cudaSetDevice(self.device_id))
-        if self.verbose_nvtx:
-            nvtx_profile_start("read_tokens", self.nvtx_markers, color='yellow')
-        actual_batch_size = len(samples)
-        sample_indices = [q.index for q in samples]
-        sample_ids = [q.id for q in samples]
-        self._verbose_info(f"[Device {self.device_id}] Running inference on sample {sample_indices} with batch size {actual_batch_size}")
-
-        # TODO add copy stream support
-        prompt_tokens_clip1 = self.dataset.prompt_tokens_clip1[sample_indices, :].to(self.device)
-        prompt_tokens_clip2 = self.dataset.prompt_tokens_clip2[sample_indices, :].to(self.device)
-        negative_prompt_tokens_clip1 = self.dataset.negative_prompt_tokens_clip1[sample_indices, :].to(self.device)
-        negative_prompt_tokens_clip2 = self.dataset.negative_prompt_tokens_clip2[sample_indices, :].to(self.device)
+        logging.info("generate_images")
+        # print(samples)
+        # print(samples[0].id)
+        try:
+            CUASSERT(cudart.cudaSetDevice(self.device_id))
+            if self.verbose_nvtx:
+                nvtx_profile_start("read_tokens", self.nvtx_markers, color='yellow')
+            actual_batch_size = len(samples)
+            sample_indices = [q.index for q in samples]
+            sample_ids = [q.id for q in samples]
+            self._verbose_info(f"[Device {self.device_id}] Running inference on sample {sample_indices} with batch size {actual_batch_size}")
+
+            # TODO add copy stream support
+            prompt_tokens_clip1 = self.dataset.prompt_tokens_clip1[sample_indices, :].to(self.device)
+            prompt_tokens_clip2 = self.dataset.prompt_tokens_clip2[sample_indices, :].to(self.device)
+            negative_prompt_tokens_clip1 = self.dataset.negative_prompt_tokens_clip1[sample_indices, :].to(self.device)
+            negative_prompt_tokens_clip2 = self.dataset.negative_prompt_tokens_clip2[sample_indices, :].to(self.device)
 
-        if self.verbose_nvtx:
-            nvtx_profile_stop("read_tokens", self.nvtx_markers)
-            nvtx_profile_start("stage_clip_buffers", self.nvtx_markers, color='pink')
-        self._transfer_to_clip_buffer(
-            prompt_tokens_clip1,
-            prompt_tokens_clip2,
-            negative_prompt_tokens_clip1,
-            negative_prompt_tokens_clip2
-        )
-        if self.verbose_nvtx:
-            nvtx_profile_stop("stage_clip_buffers", self.nvtx_markers)
-            # nvtx_profile_start("generate_images", self.nvtx_markers)
+            if self.verbose_nvtx:
+                nvtx_profile_stop("read_tokens", self.nvtx_markers)
+                nvtx_profile_start("stage_clip_buffers", self.nvtx_markers, color='pink')
+            self._transfer_to_clip_buffer(
+                prompt_tokens_clip1,
+                prompt_tokens_clip2,
+                negative_prompt_tokens_clip1,
+                negative_prompt_tokens_clip2
+            )
+            if self.verbose_nvtx:
+                nvtx_profile_stop("stage_clip_buffers", self.nvtx_markers)
+                # nvtx_profile_start("generate_images", self.nvtx_markers)
 
-        self._encode_tokens(actual_batch_size)
-        self._denoise_latent(actual_batch_size)  # runs self.denoising_steps inside
-        self._decode_latent(actual_batch_size)
+            self._encode_tokens(actual_batch_size)
+            self._denoise_latent(actual_batch_size)  # runs self.denoising_steps inside
+            self._decode_latent(actual_batch_size)
 
-        self._save_buffer_to_images()
+            self._save_buffer_to_images()
+
+            # Report back to loadgen use sample_ids
+            # response = SDXLResponse(sample_ids=sample_ids,
+            #                         generated_images=self.copy_stream.vae_outputs,
+            #                         results_ready=self.copy_stream.d2h_event)
+            # self.response_queue.put(response)
+
+            qsr = []
+            actual_batch_size = len(samples)
+            self.total_samples += actual_batch_size
+            return self.copy_stream.vae_outputs
+
+            CUASSERT(cudart.cudaEventSynchronize(response.results_ready))
+            self._verbose_info(f"[Device {self.device_id}] Reporting back {actual_batch_size} samples")
+
+            if self.verbose_nvtx:
+                nvtx_profile_start("report_qsl", self.nvtx_markers, color='yellow')
+
+            for idx, sample_id in enumerate(response.sample_ids):
+                qsr.append(lg.QuerySampleResponse(sample_id,
+                                                  response.generated_images[idx].data_ptr(),
+                                                  response.generated_images[idx].nelement() * response.generated_images[idx].element_size()))
+
+            # breakpoint()
+            lg.QuerySamplesComplete(qsr)
+
+            if self.verbose_nvtx:
+                nvtx_profile_stop("report_qsl", self.nvtx_markers)
+
+
+        except Exception as e:
+            tb = traceback.format_exc()
+            logging.info(f"Actor died due to error: {e}\n{tb}")
+            raise
 
-        # Report back to loadgen use sample_ids
-        response = SDXLResponse(sample_ids=sample_ids,
-                                generated_images=self.copy_stream.vae_outputs,
-                                results_ready=self.copy_stream.d2h_event)
-        self.response_queue.put(response)
 
     def warm_up(self, warm_up_iters):
         CUASSERT(cudart.cudaSetDevice(self.device_id))
@@ -700,8 +736,9 @@ def __init__(self,
                  verbose_nvtx: bool = False,
                  enable_batcher: bool = False,
                  batch_timeout_threashold: float = -1):
-
-        self.devices = devices
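+        # One SDXLCore actor is created per GPU that Ray reports for this node,
+        # so the worker pool is sized from the cluster resources rather than
+        # from the devices argument.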
+        ray.init()
+        self.num_gpus = int(ray.available_resources()["GPU"])
+        self.devices = range(self.num_gpus)
         self.gpu_batch_size = gpu_batch_size
         self.verbose = verbose
         self.verbose_nvtx = verbose_nvtx
@@ -716,10 +753,13 @@ def __init__(self,
         self.sample_count = 0
         self.sdxl_cores = {}
         self.core_threads = []
+
+        self.future = []
+        self.result_queue = queue.Queue()
 
         # Initialize the cores
         for device_id in self.devices:
-            self.sdxl_cores[device_id] = SDXLCore(device_id=device_id,
+            self.sdxl_cores[device_id] = SDXLCore.remote(device_id=device_id,
                                                   dataset=dataset,
                                                   gpu_engine_files=gpu_engine_files,
                                                   gpu_batch_size=self.gpu_batch_size,
@@ -728,12 +768,9 @@ def __init__(self,
                                                   verbose=self.verbose,
                                                   verbose_nvtx=self.verbose_nvtx)
 
-        # Start the cores
-        for device_id in self.devices:
-            thread = threading.Thread(target=self.process_samples, args=(device_id,))
-            # thread.daemon = True
-            self.core_threads.append(thread)
-            thread.start()
+
+        self.report_thread = threading.Thread(target=self.report_complete)
+        self.report_thread.start()
 
         if self.enable_batcher:
             self.batcher_threshold = batch_timeout_threashold  # maximum seconds to form a batch
@@ -749,7 +786,8 @@ def _verbose_info(self, msg):
 
     def warm_up(self):
        for device_id in self.devices:
-            self.sdxl_cores[device_id].warm_up(warm_up_iters=2)
+            future = self.sdxl_cores[device_id].warm_up.remote(warm_up_iters=2)
+            ray.get(future)
 
     def process_samples(self, device_id):
         while True:
@@ -758,8 +796,30 @@ def process_samples(self, device_id):
                 # None in the queue indicates the SUT want us to exit
                 self.sample_queue.task_done()
                 break
-            self.sdxl_cores[device_id].generate_images(samples)
+            future = self.sdxl_cores[device_id].generate_images.remote(samples)
+            # output = ray.get(future)
             self.sample_queue.task_done()
+            self.result_queue.put((samples, future))
+
+    def report_complete(self):
+        while True:
+
+            res = self.result_queue.get()
+            if res is None:
+                self.result_queue.task_done()
+                break
+            self.result_queue.task_done()
+            samples = res[0]
+            future = res[1]
+            output = ray.get(future)
+            qsr = []
+            for idx in range(len(samples)):
+                sample_id = samples[idx].id
+                generated_images = output
+                qsr.append(lg.QuerySampleResponse(sample_id,
+                                                  generated_images[idx].data_ptr(),
+                                                  generated_images[idx].nelement() * generated_images[idx].element_size()))
+            # breakpoint()
+            lg.QuerySamplesComplete(qsr)
 
     def batch_samples(self):
         batched_samples = self.batcher_queue.get()
@@ -784,13 +844,32 @@ def issue_queries(self, query_samples):
         num_samples = len(query_samples)
         self._verbose_info(f"[Server] Received {num_samples} samples")
         self.sample_count += num_samples
-        for i in range(0, num_samples, self.gpu_batch_size):
-            # Construct batches
-            actual_batch_size = self.gpu_batch_size if num_samples - i > self.gpu_batch_size else num_samples - i
-            if self.enable_batcher:
-                self.batcher_queue.put(query_samples[i: i + actual_batch_size])
-            else:
-                self.sample_queue.put(query_samples[i: i + actual_batch_size])
+        logging.info(f"{num_samples}")
+
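+        # Full gpu_batch_size chunks are dealt round-robin to the actors; the
+        # remainder is then split as evenly as possible, with the first few
+        # actors taking one extra sample each. Each (samples, future) pair goes
+        # onto result_queue for report_complete to hand back to LoadGen.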
+        samples = []
+        total_batch = self.gpu_batch_size * self.num_gpus
+        t = num_samples // total_batch
+        left = num_samples % total_batch
+        idx = 0
+        for i in range(t):
+            for device_id in self.devices:
+                sample = query_samples[idx: idx + self.gpu_batch_size]
+                idx += self.gpu_batch_size
+                future = self.sdxl_cores[device_id].generate_images.remote(sample)
+                self.result_queue.put((sample, future))
+
+        batch_size = left // self.num_gpus
+        left = left % self.num_gpus
+        for i in self.devices:
+            len_ = batch_size
+            if i < left:
+                len_ += 1
+            sample = query_samples[idx: idx + len_]
+            idx += len_
+            future = self.sdxl_cores[i].generate_images.remote(sample)
+            self.result_queue.put((sample, future))
+
+
 
     def flush_queries(self):
         pass
@@ -799,16 +878,19 @@ def finish_test(self):
         # exit all threads
         self._verbose_info(f"SUT finished!")
         logging.info(f"[Server] Received {self.sample_count} total samples")
-        for _ in self.core_threads:
-            self.sample_queue.put(None)
-        self.sample_queue.join()
+        # for _ in self.core_threads:
+        #     self.sample_queue.put(None)
+        self.result_queue.put(None)
+        self.result_queue.join()
+        # self.sample_queue.join()
         if self.enable_batcher:
             self.batcher_queue.put(None)
             self.batcher_thread.join()
         for device_id in self.devices:
-            logging.info(f"[Device {device_id}] Reported {self.sdxl_cores[device_id].total_samples} samples")
+            logging.info(f"[Device {device_id}] Reported {ray.get(self.sdxl_cores[device_id].get_total_samples.remote())} samples")
         for thread in self.core_threads:
             thread.join()
+        self.report_thread.join()
 
 
 if __name__ == '__main__':