
support filling partial rows from backend #4158

Closed · wants to merge 1 commit

@@ -58,8 +58,6 @@ class KVZCHParams(NamedTuple):
bucket_sizes: List[int] = []
# enable optimizer offloading or not
enable_optimizer_offloading: bool = False
# streaming load/save checkpoint chunk size
streaming_ckpt_chunk_size: int = 1000000

def validate(self) -> None:
assert len(self.bucket_offsets) == len(self.bucket_sizes), (
49 changes: 12 additions & 37 deletions fbgemm_gpu/fbgemm_gpu/tbe/ssd/training.py
@@ -896,6 +896,8 @@ def _insert_all_kv(self) -> None:
total_dim0 += dim0

start_ts = time.time()
# TODO: is there a case for non-KVZCH SSD with bulk init enabled + optimizer offloading? Probably not.
# If such a case exists, we should init only the emb dim, not the optimizer dim.
chunk_tensor = torch.empty(
row_count,
self.cache_row_dim,
@@ -1944,9 +1946,8 @@ def split_optimizer_states(

dtype = self.weights_precision.as_dtype()
optimizer_dim = self.optimizer.state_size_dim(dtype)
pad4_optimizer_dim = pad4(optimizer_dim)
logging.info(
f"split_optimizer_states: {optimizer_dim=} {pad4_optimizer_dim=} {self.optimizer.dtype()=} {self.enable_load_state_dict_mode=}"
f"split_optimizer_states: {optimizer_dim=}, {self.optimizer.dtype()=} {self.enable_load_state_dict_mode=}"
)

for t, (emb_height, emb_dim) in enumerate(self.embedding_specs):
@@ -1972,7 +1973,7 @@
self.momentum1_dev.detach().cpu()[local_id_tensor].view(-1),
)
else:
emb_opt_dim = pad4(emb_dim) + pad4_optimizer_dim
emb_opt_dim = pad4(emb_dim) + optimizer_dim
row_offset = table_offset - (bucket_id_start * bucket_size)
# using KVTensorWrapper to query backend to avoid OOM memory, since
# backend will return both weight and optimizer in one tensor, read the whole tensor
@@ -1984,54 +1985,28 @@
snapshot_handle=snapshot_handle,
materialized_shape=([sorted_id_tensor[t].size(0), emb_opt_dim]),
sorted_indices=sorted_id_tensor[t],
width_offset=pad4(emb_dim),
)
(
tensor_wrapper.set_embedding_rocks_dp_wrapper(self.ssd_db)
if self.backend_type == BackendType.SSD
else tensor_wrapper.set_dram_db_wrapper(self.ssd_db)
)
opt_list.append(
self.get_offloaded_optimizer_states(
tensor_wrapper=tensor_wrapper,
row=sorted_id_tensor[t].size(
0
), # we only need to copy the size of sorted_id_tensor
optimizer_dim=optimizer_dim,
start_dim_pos=pad4(emb_dim),
tensor_wrapper.narrow(
0,
0,
sorted_id_tensor[t].size(0),
)
.view(-1)
.view(self.optimizer.dtype())
)
table_offset += emb_height
logging.info(
f"KV ZCH tables split_optimizer_states query latency: {(time.time() - start_time) * 1000} ms"
)
return opt_list

@torch.jit.export
def get_offloaded_optimizer_states(
self,
# pyre-ignore [2]
tensor_wrapper,
row: int,
optimizer_dim: int,
start_dim_pos: int,
) -> torch.Tensor:
weight_dtype = self.weights_precision.as_dtype()
opt_state_t = torch.empty(
row, optimizer_dim, dtype=weight_dtype, device="cpu"
) # 1D optimizer for OptimType.EXACT_ROWWISE_ADAGRAD

# pyre-ignore [16]
chunk_size = self.kv_zch_params.streaming_ckpt_chunk_size
for i in range(0, row, chunk_size):
length = min(chunk_size, row - i)
opt_state_t.narrow(0, i, length).copy_(
tensor_wrapper.narrow(0, i, length).narrow(
1, start_dim_pos, optimizer_dim
)
)
# view optimizer state back to correct dtype
return opt_state_t.view(-1).view(self.optimizer.dtype())

@torch.jit.export
def get_optimizer_state(
self,
@@ -2207,7 +2182,7 @@ def split_embedding_weights(
if bucket_ascending_id_tensor is not None
else emb_height
),
emb_dim,
pad4(emb_dim),
],
dtype=dtype,
row_offset=table_offset,
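
Note on the `split_optimizer_states` change above: with optimizer offloading, each backend row stores the weight columns padded to a multiple of 4, followed by the optimizer state, so the wrapper can now be asked for just the optimizer slice (`width_offset=pad4(emb_dim)`) instead of streaming full rows in chunks and narrowing them on the host (the removed `get_offloaded_optimizer_states` path). A minimal torch-only sketch of the column arithmetic; the `pad4` helper and the `backend_rows` tensor here are illustrative stand-ins, not the FBGEMM implementation:

```python
import torch

def pad4(n: int) -> int:
    # assumed semantics: round up to a multiple of 4 elements
    return (n + 3) // 4 * 4

emb_dim, optimizer_dim = 30, 2               # toy sizes: fp32 weights + rowwise adagrad state
emb_opt_dim = pad4(emb_dim) + optimizer_dim  # full backend row width: 32 + 2 = 34

# toy stand-in for the rows the backend holds: [weights | pad | optimizer state]
backend_rows = torch.randn(5, emb_opt_dim)

# old path: fetch all emb_opt_dim columns, then slice out the optimizer state on the host
old_opt = backend_rows[:, pad4(emb_dim):pad4(emb_dim) + optimizer_dim]

# new path: the KVTensorWrapper is built with width_offset=pad4(emb_dim), so only the
# optimizer columns are materialized; host-side they are just flattened and reinterpreted
# in the optimizer dtype (fp32 in this toy example)
new_opt = backend_rows[:, pad4(emb_dim):emb_opt_dim]     # what the partial read returns
opt_state = new_opt.contiguous().view(-1).view(torch.float32)

assert torch.equal(old_opt.reshape(-1), opt_state)
```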
@@ -457,7 +457,7 @@ void EmbeddingKVDB::set(
<< "]skip set_cuda since number evictions is " << num_evictions;
return;
}

CHECK_EQ(max_D_, weights.size(1));
// defer the L2 cache/rocksdb update to the background thread as it could
// be parallelized with other cuda kernels, as long as all updates are
// finished before the next L2 cache lookup
@@ -271,11 +271,15 @@ class EmbeddingKVDB : public std::enable_shared_from_this<EmbeddingKVDB> {
const at::Tensor& weights,
const int64_t start,
const int64_t length,
const ssd::SnapshotHandle* snapshot_handle) {
const ssd::SnapshotHandle* snapshot_handle,
int64_t width_offset = 0,
std::optional<int64_t> width_length = std::nullopt) {
(void)weights;
(void)start;
(void)length;
(void)snapshot_handle;
(void)width_offset;
(void)width_length;
FBEXCEPTION("Not implemented");
}

@@ -287,10 +291,14 @@ class EmbeddingKVDB : public std::enable_shared_from_this<EmbeddingKVDB> {
virtual void get_kv_from_storage_by_snapshot(
const at::Tensor& ids,
const at::Tensor& weights,
const ssd::SnapshotHandle* snapshot_handle) {
const ssd::SnapshotHandle* snapshot_handle,
int64_t width_offset = 0,
std::optional<int64_t> width_length = std::nullopt) {
(void)ids;
(void)weights;
(void)snapshot_handle;
(void)width_offset;
(void)width_length;
FBEXCEPTION("Not implemented");
}

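
The two virtual reads above gain `width_offset` and an optional `width_length`, so a backend can fill only a column window of each requested row instead of always returning `max_D` columns. A rough emulation of the intended contract, using an in-memory tensor as the "storage"; this is an illustrative sketch, not the RocksDB/DRAM override:

```python
import torch
from typing import Optional

def get_range_from_snapshot_emulated(
    storage: torch.Tensor,       # [num_rows, max_D] stand-in for the backend store
    out_weights: torch.Tensor,   # [length, width] caller-provided output buffer
    start: int,
    length: int,
    width_offset: int = 0,
    width_length: Optional[int] = None,
) -> None:
    # width_length defaults to "everything from width_offset to the end of the row"
    if width_length is None:
        width_length = storage.size(1) - width_offset
    assert out_weights.shape == (length, width_length)
    # copy only the requested column window of the requested row range
    out_weights.copy_(storage[start:start + length, width_offset:width_offset + width_length])

# e.g. read only the two optimizer columns (starting at the padded emb width of 32)
store = torch.randn(100, 40)
buf = torch.empty(10, 2)
get_range_from_snapshot_emulated(store, buf, start=10, length=10, width_offset=32, width_length=2)
```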
@@ -45,7 +45,8 @@ class KVTensorWrapper : public torch::jit::CustomClassHolder {
int64_t row_offset,
std::optional<c10::intrusive_ptr<EmbeddingSnapshotHandleWrapper>>
snapshot_handle = std::nullopt,
std::optional<at::Tensor> sorted_indices = std::nullopt);
std::optional<at::Tensor> sorted_indices = std::nullopt,
int64_t width_offset = 0);

at::Tensor narrow(int64_t dim, int64_t start, int64_t length);

@@ -97,6 +98,7 @@ class KVTensorWrapper : public torch::jit::CustomClassHolder {
std::vector<int64_t> strides_;
int64_t row_offset_;
std::optional<at::Tensor> sorted_indices_ = std::nullopt;
int64_t width_offset_;
};

} // namespace ssd
@@ -34,7 +34,8 @@ KVTensorWrapper::KVTensorWrapper(
int64_t row_offset,
[[maybe_unused]] const std::optional<
c10::intrusive_ptr<EmbeddingSnapshotHandleWrapper>> snapshot_handle,
[[maybe_unused]] const std::optional<at::Tensor> sorted_indices)
[[maybe_unused]] const std::optional<at::Tensor> sorted_indices,
[[maybe_unused]] int64_t width_offset)
// @lint-ignore CLANGTIDY clang-diagnostic-missing-noreturn
: shape_(std::move(shape)), row_offset_(row_offset) {
FBEXCEPTION("Not implemented");
@@ -308,8 +308,13 @@ KVTensorWrapper::KVTensorWrapper(
int64_t row_offset,
const std::optional<c10::intrusive_ptr<EmbeddingSnapshotHandleWrapper>>
snapshot_handle,
std::optional<at::Tensor> sorted_indices)
: db_(nullptr), shape_(std::move(shape)), row_offset_(row_offset) {
std::optional<at::Tensor> sorted_indices,
int64_t width_offset_)
: db_(nullptr),
shape_(std::move(shape)),
row_offset_(row_offset),
width_offset_(width_offset_) {
CHECK_GE(width_offset_, 0);
CHECK_EQ(shape_.size(), 2) << "Only 2D emb tensors are supported";
options_ = at::TensorOptions()
.dtype(static_cast<c10::ScalarType>(dtype))
@@ -342,25 +347,28 @@ void KVTensorWrapper::set_dram_db_wrapper(

at::Tensor KVTensorWrapper::narrow(int64_t dim, int64_t start, int64_t length) {
CHECK_EQ(dim, 0) << "Only narrow on dim 0 is supported";
CHECK_GE(db_->get_max_D(), shape_[1]);
CHECK_TRUE(db_ != nullptr);
// Do not force snapshot handle is not nullptr since DRAM DB does not have
// rocksdb snapshot
CHECK_GE(db_->get_max_D(), shape_[1]);
TORCH_CHECK(
(snapshot_handle_ == nullptr) ==
(std::dynamic_pointer_cast<EmbeddingRocksDB>(db_).get() == nullptr),
"snapshot handler must be valid for rocksdb and nullptr for emb kvdb");
if (!sorted_indices_.has_value()) {
auto t = at::empty(c10::IntArrayRef({length, db_->get_max_D()}), options_);
int64_t tensor_width = shape_[1] - width_offset_;
auto t = at::empty(c10::IntArrayRef({length, tensor_width}), options_);
db_->get_range_from_snapshot(
t,
start + row_offset_,
length,
snapshot_handle_ != nullptr ? snapshot_handle_->handle : nullptr);
// TBE may have multiple embeddings in one table padded to max D
// narrow to the actual shape here before returning
return t.narrow(1, 0, shape_[1]).contiguous();
snapshot_handle_ != nullptr ? snapshot_handle_->handle : nullptr,
width_offset_,
tensor_width);
CHECK(t.is_contiguous());
return t;
} else {
at::Tensor sliced_ids =
sorted_indices_.value().slice(0, start, start + length);
auto out_weights = get_weights_by_ids(sliced_ids);
return out_weights.narrow(1, 0, shape_[1]).contiguous();
return get_weights_by_ids(sliced_ids);
}
}
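
With `width_offset_` honored by the backend, `narrow()` above allocates the output at `shape_[1] - width_offset_` columns and the backend writes it in place, so the old pattern of fetching rows at `max_D` width and then doing `t.narrow(1, 0, shape_[1]).contiguous()` (an extra per-row copy) goes away. A before/after sketch of just the shape arithmetic, with made-up sizes:

```python
import torch

max_D = 128        # backend row width (all tables padded to this)
shape_1 = 34       # this wrapper's row width, e.g. pad4(emb_dim) + optimizer_dim
width_offset = 32  # start reading at the optimizer columns
length = 8         # rows requested via narrow(0, start, length)

# before: allocate at max_D, fetch full rows, then narrow + copy down to shape_1 columns
full = torch.empty(length, max_D)
before = full.narrow(1, 0, shape_1).contiguous()   # extra copy of every returned row

# after: allocate exactly the columns that will be read; the backend fills them directly
tensor_width = shape_1 - width_offset              # 2 optimizer columns in this example
after = torch.empty(length, tensor_width)          # already contiguous, no narrow needed
assert after.is_contiguous() and after.shape == (length, tensor_width)
```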

@@ -404,14 +412,23 @@ void KVTensorWrapper::set_weights_and_ids(

at::Tensor KVTensorWrapper::get_weights_by_ids(const at::Tensor& ids) {
CHECK_TRUE(db_ != nullptr);
CHECK_GE(db_->get_max_D(), shape_[1]);
TORCH_CHECK(
(snapshot_handle_ == nullptr) ==
(std::dynamic_pointer_cast<EmbeddingRocksDB>(db_).get() == nullptr),
"snapshot handler must be valid for rocksdb and nullptr for emb kvdb");
int64_t tensor_width = shape_[1] - width_offset_;
auto weights =
at::empty(c10::IntArrayRef({ids.size(0), db_->get_max_D()}), options_);
at::empty(c10::IntArrayRef({ids.size(0), tensor_width}), options_);
auto linearized_ids = ids + row_offset_;
db_->get_kv_from_storage_by_snapshot(
linearized_ids,
weights,
snapshot_handle_ != nullptr ? snapshot_handle_->handle : nullptr);
return weights.narrow(1, 0, shape_[1]);
snapshot_handle_ != nullptr ? snapshot_handle_->handle : nullptr,
width_offset_,
tensor_width);
CHECK(weights.is_contiguous());
return weights;
}

c10::IntArrayRef KVTensorWrapper::sizes() {
@@ -634,15 +651,17 @@ static auto kv_tensor_wrapper =
int64_t,
std::optional<
c10::intrusive_ptr<EmbeddingSnapshotHandleWrapper>>,
std::optional<at::Tensor>>(),
std::optional<at::Tensor>,
int64_t>(),
"",
{torch::arg("shape"),
torch::arg("dtype"),
torch::arg("row_offset"),
// snapshot must be provided for reading
// not needed for writing
torch::arg("snapshot_handle") = std::nullopt,
torch::arg("sorted_indices") = std::nullopt})
torch::arg("sorted_indices") = std::nullopt,
torch::arg("width_offset") = 0})
.def(
"set_embedding_rocks_dp_wrapper",
&KVTensorWrapper::set_embedding_rocks_dp_wrapper,