Enable configurable deterministic algorithm flag #150

Open · wants to merge 1 commit into main
15 changes: 15 additions & 0 deletions README.md
@@ -180,6 +180,21 @@ key: "ENABLE_WEIGHT_SHARING"
}
```

* `ENABLE_DETERMINISTIC_ALGORITHMS`: Boolean flag to enable deterministic algorithm selection for TorchScript models. By default, deterministic algorithms are disabled.

When this flag is set to `true`, Triton configures the PyTorch backend to use only deterministic algorithm implementations, so model outputs are reproducible across runs at the cost of potential performance degradation. If an operation has no deterministic implementation, PyTorch raises an error instead of silently running a nondeterministic kernel; a minimal standalone LibTorch sketch of this setting follows the config example below.

The section of model config file specifying this parameter will look like:

```
parameters: {
key: "ENABLE_DETERMINISTIC_ALGORITHMS"
value: {
string_value: "true"
}
}
```
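
Readers who want to see the same behavior outside of Triton can reproduce it with a short standalone LibTorch program. This is a minimal sketch, not part of this PR, and it assumes a LibTorch build recent enough to expose the two-argument `setDeterministicAlgorithms`:

```
// Minimal sketch: toggle PyTorch's global deterministic-algorithms setting
// directly through the LibTorch C++ API, which is what the backend does when
// ENABLE_DETERMINISTIC_ALGORITHMS is "true".
#include <torch/torch.h>

#include <iostream>

int main() {
  // warn_only = false: an op without a deterministic implementation throws
  // an error instead of silently running a nondeterministic kernel.
  at::globalContext().setDeterministicAlgorithms(true, /*warn_only=*/false);

  std::cout << "deterministic algorithms enabled: "
            << at::globalContext().deterministicAlgorithms() << std::endl;
  return 0;
}
```

Note that full run-to-run reproducibility in practice also depends on fixed seeds and identical inputs; this flag only controls algorithm selection.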

* `ENABLE_CACHE_CLEANING`: Boolean flag to enable CUDA cache cleaning after each model execution.
If not specified, cache cleaning is disabled. This flag has no effect if the model is on CPU.
Setting this flag to true will negatively impact the performance due to the additional CUDA cache
cleaning operation after each model execution.
32 changes: 32 additions & 0 deletions src/libtorch.cc
@@ -108,6 +108,10 @@ class ModelState : public BackendModel {
bool EnabledCacheCleaning() { return enable_cache_cleaning_; }

bool EnabledWeightSharing() { return enable_weight_sharing_; }
bool EnableDeterministicAlgorithms()
{
return enable_deterministic_algorithms_;
}
const std::map<std::string, std::pair<int64_t, int64_t>>& ModelOutputs()
{
return model_outputs_;
@@ -136,6 +140,9 @@ class ModelState : public BackendModel {
// Flag to indicate whether weight sharing is enabled. Defaults to false.
bool enable_weight_sharing_;

// Flag to indicate whether deterministic algorithms are enabled. Defaults
// to false.
bool enable_deterministic_algorithms_;

// Flag pairs to indicate if various JIT settings are set and
// enabled respectively. Defaults to (false, true). Default behavior
// is to do nothing if not explicitly set.
@@ -233,6 +240,7 @@ ModelState::ModelState(TRITONBACKEND_Model* triton_model)
: BackendModel(triton_model), enable_optimized_execution_(true),
enable_inference_mode_(true), enable_cudnn_(true),
enable_cache_cleaning_(false), enable_weight_sharing_(false),
enable_deterministic_algorithms_(false),
enable_tensor_fuser_pair_({false, true}),
enable_jit_profiling_pair_({false, true}),
enable_jit_executor_pair_({false, true})
@@ -455,6 +463,26 @@ ModelState::ParseParameters()
.c_str());
}

// If `ENABLE_DETERMINISTIC_ALGORITHMS` is not present in 'parameters' then
// no update is made to 'enable_deterministic_algorithms_'.
err = ParseParameter(
params, "ENABLE_DETERMINISTIC_ALGORITHMS",
&enable_deterministic_algorithms_);
if (err != nullptr) {
if (TRITONSERVER_ErrorCode(err) != TRITONSERVER_ERROR_NOT_FOUND) {
return err;
} else {
TRITONSERVER_ErrorDelete(err);
}
} else {
LOG_MESSAGE(
TRITONSERVER_LOG_INFO,
(std::string("Deterministic algorithms are ") +
(enable_deterministic_algorithms_ ? "enabled" : "disabled") +
" for model instance '" + Name() + "'")
.c_str());
}

// If 'ENABLE_JIT_PROFILING' is not present in 'parameters' then no update
// is made to 'enable_jit_profiling'.
bool enable_jit_profiling = false;
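
As a side note, the absent-key handling above (treat `TRITONSERVER_ERROR_NOT_FOUND` as "keep the default", propagate anything else) is a pattern this file repeats for several flags. Below is a hedged sketch of how it could be factored into a helper; it is not part of this PR, and the exact `ParseParameter` overload and `triton::common::TritonJson::Value` parameter type are assumptions based on the calls shown above.

```
// Hypothetical helper (not in this PR): parse an optional boolean parameter,
// leaving *value untouched when the key is absent from the model config.
TRITONSERVER_Error*
ParseOptionalBoolParameter(
    triton::common::TritonJson::Value& params, const std::string& key,
    bool* value)
{
  TRITONSERVER_Error* err = ParseParameter(params, key, value);
  if (err != nullptr) {
    if (TRITONSERVER_ErrorCode(err) != TRITONSERVER_ERROR_NOT_FOUND) {
      return err;  // real parse error: propagate to the caller
    }
    TRITONSERVER_ErrorDelete(err);  // key absent: keep the caller's default
  }
  return nullptr;  // parsed successfully, or key absent
}
```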
@@ -1588,6 +1616,10 @@ ModelInstanceState::Execute(
// enable/disable cudnn
at::globalContext().setUserEnabledCuDNN(model_state_->EnabledCudnn());

// enable/disable deterministic algorithms
at::globalContext().setDeterministicAlgorithms(
model_state_->EnableDeterministicAlgorithms(), false /* warn_only */);

// JIT. No change is made unless parameter is explicitly set.
if (std::get<0>(model_state_->EnabledJitProfiling())) {
torch::jit::getProfilingMode() =
    std::get<1>(model_state_->EnabledJitProfiling());
}
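
One design note on the `setDeterministicAlgorithms` call above: the PR passes `warn_only = false`, so an inference that reaches an op with no deterministic implementation fails with an error rather than silently falling back. A hypothetical variant (not what this PR does) that prefers a warning would only change that argument:

```
// Hypothetical variant (not in this PR): warn instead of erroring when an op
// has no deterministic implementation.
at::globalContext().setDeterministicAlgorithms(
    model_state_->EnableDeterministicAlgorithms(), true /* warn_only */);
```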