
Commit 5e2d988

q10 authored and facebook-github-bot committed
Update the rowwise adagrad optimizer to leverage optimizer state offloading, v3, backend (pytorch#4133)
Summary:
Pull Request resolved: pytorch#4133
X-link: facebookresearch/FBGEMM#1214

This diff adds support for leveraging optimizer state offloading to make optimizer state updates, starting with the rowwise adagrad optimizer.

- Add compile-time flag `kEnableOptimizerOffloading` to the table update kernel to enable handling optimizer offloading, starting with the rowwise adagrad case
- Propagate the compile-time flag upwards to `embedding_backward_split_template.cu`, where it is a runtime user-supplied boolean argument

Differential Revision: D74827718
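To make the mechanism concrete, here is a minimal host-side C++ sketch (not FBGEMM code) of the pattern the summary describes: the update body is specialized on a compile-time boolean, and a runtime, user-supplied flag selects which instantiation runs. `RowState` and `update_row` are hypothetical stand-ins; only the rowwise adagrad arithmetic mirrors the diff below.

```cpp
// Hedged sketch of compile-time specialization selected by a runtime flag.
#include <cmath>
#include <cstdio>
#include <vector>

struct RowState {
  float momentum;  // optimizer state stored alongside the embedding row
};

template <bool kEnableOptimizerOffloading>
float update_row(float g_avg_square,
                 RowState* row_state,  // state embedded in the cache row
                 float* momentum1,     // standalone momentum tensor
                 int idx,
                 float learning_rate,
                 float eps) {
  float new_sum_square_grads = g_avg_square;
  if constexpr (kEnableOptimizerOffloading) {
    // Read and write the state that lives at the end of the row.
    new_sum_square_grads += row_state->momentum;
    row_state->momentum = new_sum_square_grads;
  } else {
    // Fall back to the separate momentum1 buffer.
    new_sum_square_grads += momentum1[idx];
    momentum1[idx] = new_sum_square_grads;
  }
  return learning_rate / (std::sqrt(new_sum_square_grads) + eps);
}

int main() {
  RowState state{0.5f};
  std::vector<float> momentum1{0.5f};
  const bool enable_optimizer_offloading = true;  // runtime, user-supplied

  // The runtime flag picks the matching compile-time instantiation.
  const float multiplier = enable_optimizer_offloading
      ? update_row<true>(0.25f, &state, momentum1.data(), 0, 0.1f, 1e-8f)
      : update_row<false>(0.25f, &state, momentum1.data(), 0, 0.1f, 1e-8f);

  std::printf("multiplier = %f\n", multiplier);
  return 0;
}
```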
1 parent e0e3e8c commit 5e2d988

7 files changed: +65 −33 lines changed

fbgemm_gpu/codegen/genscript/optimizers.py

Lines changed: 22 additions & 4 deletions
@@ -186,15 +186,32 @@ def rowwise_adagrad() -> Dict[str, Any]:
         g_local_sum_square += gx * gx + gy * gy + gz * gz + gw * gw;
     """
     )
-    split_precomputation += """
+    split_precomputation += """
+    // Define the rowwise adagrad optimizer state struct view
+    struct OptimizerState {
+        at::acc_type<cache_t, true> momentum;
+    };
+
+    // Fetch the pointer to the optimizer state along the cache row
+    [[maybe_unused]] auto* optimizer = weight_row_template.template optimizer_state_ptr<OptimizerState>();
+
     const at::acc_type<cache_t, true> g_avg_square =
         GROUP_REDUCE_ALL_SUM(g_local_sum_square, at::acc_type<cache_t, true>) / D;

     at::acc_type<cache_t, true> multiplier = 0.0;
     at::acc_type<cache_t, true> correction = 0.0;
     if (threadIdx.x == 0) {
-        at::acc_type<cache_t, true> new_sum_square_grads = momentum1[idx] + g_avg_square;
-        momentum1[idx] = new_sum_square_grads;
+        auto new_sum_square_grads = g_avg_square;
+
+        // Update the optimizer state. Use optimizer state offloading only if enabled
+        if (enable_optimizer_offloading) {
+            new_sum_square_grads += optimizer->momentum;
+            optimizer->momentum = new_sum_square_grads;
+        } else {
+            new_sum_square_grads += momentum1[idx];
+            momentum1[idx] = new_sum_square_grads;
+        }
+
         multiplier = learning_rate / (sqrtf(new_sum_square_grads) + eps);
         if (weight_decay_mode == 1) {
             // L2 regularization
@@ -251,9 +268,10 @@ def rowwise_adagrad() -> Dict[str, Any]:
             OptimItem(ArgType.FLOAT, "weight_decay", 0.0),
             OptimItem(ArgType.INT, "weight_decay_mode", 0),
             OptimItem(ArgType.FLOAT, "max_norm", 0.0),
+            OptimItem(ArgType.BOOL, "enable_optimizer_offloading", False),
         ],
         {
-            "v1": "Tensor momentum1, float eps = 0, float learning_rate = 0, float weight_decay = 0.0, int weight_decay_mode = 0.0, float max_norm = 0.0"
+            "v1": "Tensor momentum1, float eps = 0, float learning_rate = 0, float weight_decay = 0.0, int weight_decay_mode = 0.0, float max_norm = 0.0",
         },
     ),
     "split_precomputation": split_precomputation,

fbgemm_gpu/codegen/training/backward/embedding_backward_split_cpu_template.cpp

Lines changed: 2 additions & 2 deletions
@@ -431,9 +431,9 @@ for (const auto d : c10::irange(D)) {
 
 TORCH_LIBRARY_FRAGMENT(fbgemm, m) {
     {% if not dense %}
-    m.def("split_embedding_backward_codegen_{{ optimizer }}_cpu(Tensor grad_output, Tensor(a!) host_weights, Tensor weights_placements, Tensor weights_offsets, Tensor D_offsets, int max_D, Tensor hash_size_cumsum, int total_hash_size_bits, Tensor indices, Tensor offsets,int pooling_mode, Tensor indice_weights, bool stochastic_rounding, {{ (args.split_function_args | join(", ")).replace("double", "float").replace("int64_t", "int").replace("Tensor momentum1_host", "Tensor(b!) momentum1_host")}}, int output_dtype = 0) -> ()");
+    m.def("split_embedding_backward_codegen_{{ optimizer }}_cpu(Tensor grad_output, Tensor(a!) host_weights, Tensor weights_placements, Tensor weights_offsets, Tensor D_offsets, int max_D, Tensor hash_size_cumsum, int total_hash_size_bits, Tensor indices, Tensor offsets,int pooling_mode, Tensor indice_weights, bool stochastic_rounding, {{ (args.split_function_args | join(", ")).replace("double", "float").replace("int64_t", "int").replace("Tensor momentum1_host", "Tensor(b!) momentum1_host").replace("false", "False")}}, int output_dtype = 0) -> ()");
     {% else %}
-    m.def("split_embedding_backward_codegen_{{ optimizer }}_cpu(Tensor grad_output, Tensor(a!) host_weights, Tensor weights_offsets, Tensor D_offsets, int max_D, Tensor hash_size_cumsum, int total_hash_size_bits, Tensor indices, Tensor offsets,int pooling_mode, Tensor indice_weights, {{ (args.split_function_args | join(", ")).replace("double", "float").replace("int64_t", "int").replace("Tensor momentum1_host", "Tensor(b!) momentum1_host")}}) -> Tensor");
+    m.def("split_embedding_backward_codegen_{{ optimizer }}_cpu(Tensor grad_output, Tensor(a!) host_weights, Tensor weights_offsets, Tensor D_offsets, int max_D, Tensor hash_size_cumsum, int total_hash_size_bits, Tensor indices, Tensor offsets,int pooling_mode, Tensor indice_weights, {{ (args.split_function_args | join(", ")).replace("double", "float").replace("int64_t", "int").replace("Tensor momentum1_host", "Tensor(b!) momentum1_host").replace("false", "False")}}) -> Tensor");
     {% endif %}
     DISPATCH_TO_CPU("split_embedding_backward_codegen_{{ optimizer }}_cpu", split_embedding_backward_codegen_{{ optimizer }}_cpu);
 }

fbgemm_gpu/codegen/training/backward/embedding_backward_split_host_cpu_template.cpp

Lines changed: 7 additions & 0 deletions
@@ -226,6 +226,13 @@ Tensor split_embedding_codegen_lookup_{{ optimizer }}_function_cpu(
   // The unified PT2 interface already accepts learning rate as tensor.
   const auto learning_rate_tensor = at::tensor({learning_rate}, at::TensorOptions().dtype(at::kFloat).device(at::kCPU));
   {%- endif %}
+
+  // V1 API is frozen. New features/functionability can only be enabled in V2 API
+  // // New arguments are added here for compatibility
+  {%- if "enable_optimizer_offloading" in args.split_function_arg_names %}
+  const bool enable_optimizer_offloading = false;
+  {%- endif %}
+
   return SplitLookupFunction_{{ optimizer }}_Op::apply(
       host_weights,
       weights_placements,

fbgemm_gpu/codegen/training/backward/embedding_backward_split_host_template.cpp

Lines changed: 6 additions & 0 deletions
@@ -1094,6 +1094,12 @@ Tensor {{ bwd_mdesc }}_embedding_codegen_lookup_{{ optimizer }}_function(
   learning_rate_tensor.fill_(learning_rate);
   {%- endif %}

+  // V1 API is frozen. New features/functionability can only be enabled in V2 API
+  // New arguments are added here for compatibility
+  {%- if "enable_optimizer_offloading" in args.split_function_arg_names %}
+  const bool enable_optimizer_offloading = false;
+  {%- endif %}
+
   {%- if not dense %}
   // Load the config value from JK once
   static auto is_tbev2_enabled = config::is_feature_enabled(config::FeatureGateName::TBE_V2);
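Both host templates pin the new flag to `false` because the V1 lookup API is frozen. As a rough C++ illustration (not FBGEMM code), the pattern is sketched below: a frozen V1 entry point keeps its existing signature and hard-codes the new option's default, while a V2-style entry point would forward the caller's choice. All function names here are hypothetical.

```cpp
// Hedged sketch of a frozen V1 entry point versus a V2-style one.
#include <cstdio>

// Backend accepting the full argument set.
float lookup_backend(float learning_rate, bool enable_optimizer_offloading) {
  std::printf("offloading=%d\n", enable_optimizer_offloading);
  return learning_rate;
}

// V1 API is frozen: new arguments cannot be added, so the flag is fixed here.
float lookup_v1(float learning_rate) {
  const bool enable_optimizer_offloading = false;  // default until exposed in V2
  return lookup_backend(learning_rate, enable_optimizer_offloading);
}

// A V2-style API would forward the caller's choice.
float lookup_v2(float learning_rate, bool enable_optimizer_offloading) {
  return lookup_backend(learning_rate, enable_optimizer_offloading);
}

int main() {
  lookup_v1(0.1f);
  lookup_v2(0.1f, true);
  return 0;
}
```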

fbgemm_gpu/codegen/training/backward/embedding_backward_split_kernel_warp_template.cu

Lines changed: 1 addition & 2 deletions
@@ -454,8 +454,7 @@ batch_index_select_dim0_codegen_backward_kernel_warp_per_row
         ph_type_combo,
         kFixedMaxVecsPerThread,
         kThreadGroupSize,
-        kUseVecBlocking
-    )
+        kUseVecBlocking)
   }}
 {%- endfor %}
 {%- endfor %}

fbgemm_gpu/codegen/training/backward/embedding_backward_split_template.cu

Lines changed: 23 additions & 25 deletions
@@ -969,7 +969,6 @@ Tensor {{ embedding_cuda_op }}(
   {%- endif %}

   DISPATCH_OPTIMAL_KERNEL(max_D, [&] {
-
     auto long_run_ids = at::empty({indices.numel()}, sorted_linear_indices_run_lengths.options());
     auto num_long_run_ids = at::zeros({1}, indices.options().dtype(at::kInt));

@@ -1032,18 +1031,17 @@ Tensor {{ embedding_cuda_op }}(
       )
       %}

-      const auto backward_cta_per_row_kernel =
-          {{ cta_kernel }}
-          <emb_t,
-           grad_t,
-           cache_t,
-           index_t,
-           {%- for ph_name in args.placeholder_tensor_names %}
-           {{ ph_name + "_ph_t" }},
-           {%- endfor %}
-           kFixedMaxVecsPerThread,
-           kThreadGroupSize,
-           kUseVecBlocking>;
+      const auto backward_cta_per_row_kernel = {{ cta_kernel }}<
+          emb_t,
+          grad_t,
+          cache_t,
+          index_t,
+          {%- for ph_name in args.placeholder_tensor_names %}
+          {{ ph_name + "_ph_t" }},
+          {%- endfor %}
+          kFixedMaxVecsPerThread,
+          kThreadGroupSize,
+          kUseVecBlocking>;

       // Compute shared memory size for cta_per_row
       constexpr auto kCacheAccBytes = sizeof(at::acc_type<cache_t, true>);
@@ -1150,18 +1148,18 @@ Tensor {{ embedding_cuda_op }}(
       desc_suffix,
       )
       %}
-      auto backward_warp_per_row_kernel =
-          {{ warp_kernel }}
-          <emb_t,
-           grad_t,
-           cache_t,
-           index_t,
-           {%- for ph_name in args.placeholder_tensor_names %}
-           {{ ph_name + "_ph_t" }},
-           {%- endfor %}
-           kFixedMaxVecsPerThread,
-           kThreadGroupSize,
-           kUseVecBlocking>;
+
+      const auto backward_warp_per_row_kernel = {{ warp_kernel }}<
+          emb_t,
+          grad_t,
+          cache_t,
+          index_t,
+          {%- for ph_name in args.placeholder_tensor_names %}
+          {{ ph_name + "_ph_t" }},
+          {%- endfor %}
+          kFixedMaxVecsPerThread,
+          kThreadGroupSize,
+          kUseVecBlocking>;

       // Compute shared memory size for warp_per_row
       int32_t num_warp_per_row_groups = kBackwardMaxThreads / kThreadGroupSize;

fbgemm_gpu/codegen/training/python/split_embedding_codegen_lookup_invoker.template

Lines changed: 4 additions & 0 deletions
@@ -352,7 +352,11 @@ def invoke(
     {%- if "optim_bool" in args_pt2.unified_pt2.split_function_arg_names %}
     optim_bool: List[bool] = []
     {%- for name in args_pt2.unified_pt2.split_args_dict["optim_bool"] %}
+    {%- if name == "enable_optimizer_offloading" %} # TODO: Remove this when the frontend lands
+    optim_bool.append(False)
+    {%- else %}
     optim_bool.append(dict_optim_bool["{{ name }}"])
+    {%- endif %}
     {%- endfor %}
     {%- endif %}
