
feat: Support weight streaming #3111

Merged: 19 commits into pytorch:main on Oct 23, 2024

Conversation

keehyuna (Collaborator):

Description

  • The weight streaming feature is exposed as a compiler option; the budget can be set as a percentage or in bytes
  • The network is created with kSTRONGLY_TYPED and kWEIGHT_STREAMING is set on the builder config (see the sketch below)
  • Layers in a strongly typed network must use matching dtypes
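
For reference, a minimal sketch of what these two flags look like at the raw TensorRT level (TensorRT 10.x Python bindings assumed; Torch-TensorRT sets the equivalent flags internally when enable_weight_streaming is passed):

    import tensorrt as trt

    # Sketch (TensorRT 10.x assumed): a strongly typed network plus a
    # weight-streamable builder config, as the bullets above describe.
    trt_logger = trt.Logger(trt.Logger.WARNING)
    builder = trt.Builder(trt_logger)
    network = builder.create_network(
        1 << int(trt.NetworkDefinitionCreationFlag.STRONGLY_TYPED)
    )
    config = builder.create_builder_config()
    config.set_flag(trt.BuilderFlag.WEIGHT_STREAMING)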

Fixes # (issue)

Type of change

Please delete options that are not relevant and/or add your own.

  • New feature (non-breaking change which adds functionality)
  • This change requires a documentation update

Checklist:

  • My code follows the style guidelines of this project (You can use the linters)
  • I have performed a self-review of my own code
  • I have commented my code, particularly in hard-to-understand areas and hacks
  • I have made corresponding changes to the documentation
  • I have added tests to verify my fix or my feature
  • New and existing unit tests pass locally with my changes
  • I have added the relevant labels to my PR so that relevant reviewers are notified

@github-actions github-actions bot added component: conversion Issues re: Conversion stage component: converters Issues re: Specific op converters component: api [Python] Issues re: Python API component: runtime component: dynamo Issues relating to the `torch.compile` or `torch._dynamo.export` paths labels Aug 22, 2024
@github-actions github-actions bot requested a review from apbose August 22, 2024 13:51
@@ -109,10 +111,119 @@ def __init__(
if self.serialized_engine is not None and not self.settings.lazy_engine_init:
self.setup_engine()

def set_weight_streaming_budget(self) -> None:
Collaborator:

@keehyuna do you need to add something similar to the C++ API?

Collaborator Author:

Sorry for the confusion. This is dead code; everything was moved to py/torch_tensorrt/runtime/_weight_streaming.py. The C++ APIs are updated in execute_engine.cpp.

@github-actions github-actions bot added component: tests Issues re: Tests component: core Issues re: The core compiler labels Aug 26, 2024
@@ -95,11 +95,13 @@ bool _cudagraphs_validate_shapes(std::vector<at::Tensor> inputs, c10::intrusive_
}

std::vector<at::Tensor> execute_engine(std::vector<at::Tensor> inputs, c10::intrusive_ptr<TRTEngine> compiled_engine) {
compiled_engine->init_context();
Collaborator:

This will add first-run latency. Why can't it run in the constructor?

Collaborator Author:

Thanks for the advice. I added it in the constructor, so the latency is paid in compile(); context creation in forward() is skipped when weight streaming is not used.

@@ -218,13 +219,25 @@ def set_weight_streaming_budget_v1(
self.engine.minimum_weight_streaming_budget
)

def reset_context(self):
Collaborator:

Do these context resets atomically with whatever runtime setting changes them. Let's keep as much out of the forward function as we can.

Collaborator Author:

I came across two ideas and tried #1. Please let me know if there is a better way to handle it automatically (a sketch of option 1 follows the list).

  1. reset_context() (delete the context) and apply the set_weight_streaming_budget() API; the context is then created at forward()
  2. Enqueue runtime setting changes such as set_weight or profile enable, then delete the context -> apply pending APIs -> create the context in forward()
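
A minimal sketch of option 1 (class and method names are hypothetical; the engine attributes are the real TensorRT ones):

    class ModuleSketch:
        # Sketch of option 1: the budget setter drops the context and
        # forward() lazily recreates it on the next call.
        def __init__(self, engine):
            self.engine = engine
            self.context = None

        def set_weight_streaming_budget(self, budget_bytes):
            # The budget cannot change while an execution context is alive,
            # so drop the context here and let forward() recreate it.
            self.context = None
            self.engine.weight_streaming_budget_v2 = budget_bytes

        def forward(self, *inputs):
            if self.context is None:
                self.context = self.engine.create_execution_context()
            # ... enqueue inference using self.context ...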

assert self.engine, f"Context is used before setting up the engine"

if self.context is None:
    self.context = self.engine.create_execution_context()
Collaborator:

We already have a setup_engine() function; not sure why we need to handle this at exec time?

Collaborator Author:

The weight streaming budget needs to be set before the context is created, or TRT throws an error. Engine setup was already completed during compile of the runtime TRT module, so the context needs to be recreated:

    # This order works:
    engine = runtime.deserialize_cuda_engine()
    engine.weight_streaming_budget_v2 = budget_bytes
    engine.create_execution_context()

    # This order fails:
    engine = runtime.deserialize_cuda_engine()
    engine.create_execution_context()
    engine.weight_streaming_budget_v2 = budget_bytes

ERROR:torch_tensorrt [TensorRT Conversion Context]:ICudaEngine::setWeightStreamingBudgetV2: Error Code 3: API Usage Error (Parameter check failed, condition: mExecutionContextCounter.use_count() == 1. The weight streaming budget cannot be modified while there are active IExecutionContexts.)

core/runtime/TRTEngine.cpp (outdated comment, resolved)
Comment on lines 124 to 149
def get_weight_streaming_budget(self):
    return self.engine.streamable_weights_size

def set_weight_streaming_budget(self, budget_bytes):
    self.reset_context()
    self.engine.weight_streaming_budget_v2 = budget_bytes
    if self.engine.weight_streaming_budget_v2 != budget_bytes:
        logger.error(f"Failed to set weight streaming budget to {budget_bytes}")
        budget_bytes = self.engine.weight_streaming_budget_v2
    if self.engine.streamable_weights_size == budget_bytes:
        logger.warning("Weight streaming is disabled")

    return budget_bytes

def set_automatic_streaming_budget(self):
    budget_bytes = self.engine.get_weight_streaming_automatic_budget()
    return self.set_weight_streaming_budget(budget_bytes)

Collaborator Author:

This API is the same as in the TorchTensorRTModule class. If this interface is good to go, a parent class could be used to share it and some other methods (a rough sketch below).
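
A rough sketch of the shared-parent idea (class name hypothetical; it assumes the subclass provides set_weight_streaming_budget and self.engine is the deserialized TRT engine):

    class WeightStreamingBudgetMixin:
        # Hypothetical shared base so both runtime module classes reuse
        # the same budget API.
        def get_weight_streaming_budget(self):
            return self.engine.streamable_weights_size

        def set_automatic_streaming_budget(self):
            budget = self.engine.get_weight_streaming_automatic_budget()
            return self.set_weight_streaming_budget(budget)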

@narendasan (Collaborator):

We probably need to think about what the user flow is here:

So @ compile-time:

  1. Users tell us to make an engine that is weight streamable

@ runtime:

  1. How do we set up the engine by default? Should there be a default weight budget?
  2. The user now wants to explicitly set the engine weight budgets
    1. How can they do this from the module level? What happens if there are multiple engines in the graph?
    2. We need to recreate the execution context; IMO this needs to be done atomically with this call to keep it out of the forward function, i.e. as part of set_weight_budget we recreate the execution context (a sketch follows)
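
A sketch of the atomic variant described in 2.2 (method name hypothetical; engine attributes are the real TensorRT ones):

    def set_weight_streaming_budget(self, budget_bytes):
        # Atomic variant: tear down, apply the budget, and rebuild the
        # context in one call so forward() never has to check state.
        self.context = None
        self.engine.weight_streaming_budget_v2 = budget_bytes
        self.context = self.engine.create_execution_context()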

    return budget_bytes

def set_automatic_streaming_budget(self):
    budget_bytes = self.engine.get_weight_streaming_automatic_budget()
Collaborator:

Seems like a good default we can use in setup_engine

Collaborator Author:

I agree. I set the automatic weight streaming budget when the compiler option is set (sketch below).
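
A sketch of that default, assuming TensorRT 10.x: apply the automatic budget before the execution context exists, e.g. inside setup_engine():

    # engine: a deserialized trt.ICudaEngine built with WEIGHT_STREAMING.
    # Automatic budget as the default when the compiler option is on.
    budget_bytes = engine.get_weight_streaming_automatic_budget()
    engine.weight_streaming_budget_v2 = budget_bytes
    context = engine.create_execution_context()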

@@ -191,6 +221,7 @@ def __del__(self) -> None:
    self.cudagraph.reset()

def forward(self, *inputs: torch.Tensor) -> torch.Tensor | Tuple[torch.Tensor, ...]:
    self.init_context()
Collaborator:

I really want to pull these calls out. It should assume that the engine is set up and error if not.

Collaborator Author:

Understood. Recreation of the context now happens only when set_weight_streaming_budget is called.

@keehyuna (Collaborator Author):

Hi @narendasan

  1. When only the compiler option is provided: the automatic weight streaming budget is applied. test case
  2. Weight stream size is set manually: a torchtrt.runtime.weight_streaming_context() runtime API is added to get the streamable size and set the budget. test case
  3. Context recreation: added a decorator on set_weight_streaming_budget.
  4. Multiple subgraphs: I'm thinking of applying a normalized size when multiple modules are in the runtime module (impl; a sketch of the proportional split is below). Too small a value will have a bad impact, so I will test various sizes.
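
One way the normalization could work (a sketch, not the final impl): split the module-level budget across engines in proportion to each engine's streamable weight size:

    def split_budget(engines, total_budget_bytes):
        # Proportional split: engines with more streamable weights get a
        # correspondingly larger share of the overall budget.
        total = sum(e.streamable_weights_size for e in engines)
        return [
            total_budget_bytes * e.streamable_weights_size // total
            for e in engines
        ]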

enable_weight_streaming=True,
)
# Weight streaming budget is applied manually.
ws_context = torchtrt.runtime.weight_streaming_context(optimized_model)
Collaborator:

Can we use context manager syntax for this?

with torch_tensorrt.runtime.weight_streaming(model) as weight_streaming_ctx:
    current_budget = weight_streaming_ctx.device_budget
    weight_streaming_ctx.device_budget = current_budget * 0.7 # Can add listeners to __setattr__ to trigger functions
    optimized_model(*input)
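
A sketch of the __setattr__ listener idea from the snippet above (the class shape is hypothetical; the attribute names follow the suggestion):

    class _WeightStreamingContextSketch:
        # Assigning to device_budget triggers the engine update.
        def __init__(self, engine):
            object.__setattr__(self, "engine", engine)
            object.__setattr__(
                self, "device_budget", engine.weight_streaming_budget_v2
            )

        def __setattr__(self, name, value):
            object.__setattr__(self, name, value)
            if name == "device_budget":
                # int() because budgets like current * 0.7 are floats.
                self.engine.weight_streaming_budget_v2 = int(value)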

@peri044 (Collaborator) left a comment:

Thoughts:

  1. If we use weight streaming as the default, is there any problem with perf? Assuming we don't allocate any budget, or automatic is chosen, and the model can fit completely in GPU memory.

Comment on lines 159 to 164
cast_layer = ctx.net.add_cast(input_val, trt_dtype)
cast_layer.name = f"Cast ITensor {input_val.name} from {input_val.dtype} to {trt_dtype} - [{target_name}]-[{name}]"

return cast_layer.get_output(0)
Collaborator:

These are currently in the llm_examples_main PR.

Collaborator:

Rebase with main, as the llm_examples PR is merged.

py/torch_tensorrt/dynamo/runtime/_TorchTensorRTModule.py (outdated comment, resolved)
py/torch_tensorrt/dynamo/conversion/_TRTInterpreter.py (outdated comment, resolved)
Comment on lines 38 to 53
if ctx.net.get_flag(trt.NetworkDefinitionCreationFlag.STRONGLY_TYPED):
    promoted_type = trt_inputs[0].dtype
    for each_input in trt_inputs[1:]:
        promoted_type = _enums.dtype._from(
            torch.promote_types(
                _enums.dtype._from(promoted_type).to(torch.dtype),
                _enums.dtype._from(each_input.dtype).to(torch.dtype),
            )
        )

    trt_promoted_type = promoted_type.to(trt.DataType)
    trt_casted_inputs = []
    for i, each_input in enumerate(trt_inputs):
        casted_input = cast_trt_tensor(
            ctx, each_input, trt_promoted_type, f"{name}_input_casted_{i}"
        )
Collaborator:

The type promotion is fine, but does it need to happen only when strong typing is enabled? Why not do this in the general case as well?

Collaborator Author:

I thought TRT could optimize perf with relaxed precision, but it seems multiple inputs to an op are eventually cast to the same type. I tested the SD UNet model with and without promoted types and there was no difference. I will generalize.
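
For reference, the promotion used in the converter above is PyTorch's standard lattice:

    import torch

    # torch.promote_types picks the common dtype the inputs are cast to.
    assert torch.promote_types(torch.float16, torch.float32) == torch.float32
    assert torch.promote_types(torch.int32, torch.float16) == torch.float16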

Comment on lines 60 to 61
dtype = input.dtype if strongly_typed else None
bias = to_numpy(bias, dtype=dtype)
Collaborator:

Shouldn't the type of bias always be input.dtype?

Collaborator Author:

When testing the fp16 variant of the sd_unet model, the bias data type is float16. It needs to be cast to run with the weight streaming option.

@@ -85,6 +85,12 @@ def __init__(
EXPLICIT_BATCH = 1 << (int)(trt.NetworkDefinitionCreationFlag.EXPLICIT_BATCH)
flag |= EXPLICIT_BATCH

if compilation_settings.enable_weight_streaming:
    STRONGLY_TYPED = 1 << (int)(trt.NetworkDefinitionCreationFlag.STRONGLY_TYPED)
    flag |= STRONGLY_TYPED
Collaborator:

We should at least log this, since it affects the graph being created.

Collaborator Author:

Waiting for the separate compiler option for the strongly typed network:
https://github.com/pytorch/TensorRT/pull/3110/files#diff-4396607120a22430fe9fdb7d00b094ae5d55f28d0d2e3543a878ac48583ebd21R83
I will incorporate it.

@narendasan (Collaborator) left a comment:

Super minor stuff at this point; I think it's almost ready to go.

@narendasan (Collaborator) left a comment:

I think this mostly looks good to me, anything outstanding?

@keehyuna (Collaborator Author):

No pending items. I think this PR can be merged.

@peri044 (Collaborator) left a comment:

Minor comments:

  1. Added a comment in the example.
  2. Also update this example reference in docsrc/index.rst so it gets rendered.
  3. Rebase with main to resolve conflicts.

Overall, changes LGTM

examples/dynamo/weight_streaming_example.py (outdated comment, resolved)
@github-actions github-actions bot added the documentation Improvements or additions to documentation label Oct 18, 2024
@peri044 peri044 merged commit 92bf700 into pytorch:main Oct 23, 2024
67 checks passed
Labels
cla signed component: api [Python] Issues re: Python API component: conversion Issues re: Conversion stage component: converters Issues re: Specific op converters component: core Issues re: The core compiler component: dynamo Issues relating to the `torch.compile` or `torch._dynamo.export` paths component: runtime component: tests Issues re: Tests documentation Improvements or additions to documentation