diff --git a/tensorflow/core/framework/embedding/lockless_hash_map.h b/tensorflow/core/framework/embedding/lockless_hash_map.h index b5a263eb436..4278266d10f 100644 --- a/tensorflow/core/framework/embedding/lockless_hash_map.h +++ b/tensorflow/core/framework/embedding/lockless_hash_map.h @@ -83,6 +83,8 @@ class LocklessHashMap : public KVInterface { std::pair*>*, long unsigned int> it = hash_map_.GetSnapshot(); hash_map_dump = it.first; bucket_count = it.second; + key_list->reserve(bucket_count); + value_ptr_list->reserve(bucket_count); for (int64 j = 0; j < bucket_count; j++) { if (hash_map_dump[j].first != LocklessHashMap::EMPTY_KEY_ && hash_map_dump[j].first != LocklessHashMap::DELETED_KEY_) { diff --git a/tensorflow/core/kernels/kv_variable_ops.h b/tensorflow/core/kernels/kv_variable_ops.h index 57291940816..78d1731e389 100644 --- a/tensorflow/core/kernels/kv_variable_ops.h +++ b/tensorflow/core/kernels/kv_variable_ops.h @@ -39,6 +39,22 @@ namespace { const int kSavedPartitionNum = 1000; } +template +std::vector concatenateVectors(const std::vector >& vectors) { + size_t totalSize = 0; + for (const auto& vec : vectors) { + totalSize += vec.size(); + } + + std::vector result(totalSize); + auto it = result.begin(); + for (const auto& vec : vectors) { + it = std::copy(vec.begin(), vec.end(), it); + } + + return result; +} + template class EVKeyDumpIterator: public DumpIterator { public: @@ -226,44 +242,47 @@ Status DumpEmbeddingValues(EmbeddingVar* ev, // so that we can dynamically load ev with changed partition number int64 filter_freq = ev->MinFreq(); for (size_t i = 0; i < tot_key_list.size(); i++) { - for (int partid = 0; partid < kSavedPartitionNum; partid++) { - if (tot_key_list[i] % kSavedPartitionNum == partid) { - if (tot_valueptr_list[i] == reinterpret_cast(-1)) { - // only forward, no backward, bypass - } else if (tot_valueptr_list[i] == nullptr) { - key_filter_list_parts[partid].push_back(tot_key_list[i]); - } else { - key_list_parts[partid].push_back(tot_key_list[i]); - valueptr_list_parts[partid].push_back(tot_valueptr_list[i]); - } - break; - } + if (tot_key_list[i] < 0) { + LOG(WARNING) << "find negative key [" << tot_key_list[i] << "] in EV (" + << tensor_key << ")"; + continue; + } + int partid = tot_key_list[i] % kSavedPartitionNum; + if (tot_valueptr_list[i] == reinterpret_cast(-1)) { + // only forward, no backward, bypass + } else if (tot_valueptr_list[i] == nullptr) { + key_filter_list_parts[partid].push_back(tot_key_list[i]); + } else { + key_list_parts[partid].push_back(tot_key_list[i]); + valueptr_list_parts[partid].push_back(tot_valueptr_list[i]); } } for (size_t i = 0; i < tot_version_list.size(); i++) { - for (int partid = 0; partid < kSavedPartitionNum; partid++) { - if (tot_key_list[i] % kSavedPartitionNum == partid) { - if (tot_valueptr_list[i] == nullptr) { - version_filter_list_parts[partid].push_back(tot_version_list[i]); - } else { - version_list_parts[partid].push_back(tot_version_list[i]); - } - break; - } + if (tot_key_list[i] < 0) { + LOG(WARNING) << "find negative key [" << tot_key_list[i] << "] in EV(" + << tensor_key << ")"; + continue; + } + int partid = tot_key_list[i] % kSavedPartitionNum; + if (tot_valueptr_list[i] == nullptr) { + version_filter_list_parts[partid].push_back(tot_version_list[i]); + } else { + version_list_parts[partid].push_back(tot_version_list[i]); } } for (size_t i = 0; i < tot_freq_list.size(); i++) { - for (int partid = 0; partid < kSavedPartitionNum; partid++) { - if (tot_key_list[i] % kSavedPartitionNum == partid) { - if (tot_valueptr_list[i] == nullptr) { - freq_filter_list_parts[partid].push_back(tot_freq_list[i]); - } else { - freq_list_parts[partid].push_back(tot_freq_list[i]); - } - break; - } + if (tot_key_list[i] < 0) { + LOG(WARNING) << "find negative key [" << tot_key_list[i] << "] in EV(" + << tensor_key << ")"; + continue; + } + int partid = tot_key_list[i] % kSavedPartitionNum; + if (tot_valueptr_list[i] == nullptr) { + freq_filter_list_parts[partid].push_back(tot_freq_list[i]); + } else { + freq_list_parts[partid].push_back(tot_freq_list[i]); } } // LOG(INFO) << "EV:" << tensor_key << ", key_list_parts:" << key_list_parts.size(); @@ -272,6 +291,23 @@ Status DumpEmbeddingValues(EmbeddingVar* ev, part_offset_flat(0) = 0; part_filter_offset[0] = 0; int ptsize = 0; + + for (int partid = 0; partid < kSavedPartitionNum; partid++) { + std::vector& key_list = key_list_parts[partid]; + std::vector& key_filter_list = key_filter_list_parts[partid]; + ptsize += key_list.size(); + part_offset_flat(partid + 1) = part_offset_flat(partid) + key_list.size(); + part_filter_offset[partid + 1] = part_filter_offset[partid] + key_filter_list.size(); + } + partitioned_tot_key_list=concatenateVectors(key_list_parts); + partitioned_tot_valueptr_list=concatenateVectors(valueptr_list_parts); + partitioned_tot_version_list=concatenateVectors(version_list_parts); + partitioned_tot_freq_list=concatenateVectors(freq_list_parts); + partitioned_tot_key_filter_list=concatenateVectors(key_filter_list_parts); + partitioned_tot_version_filter_list=concatenateVectors(version_filter_list_parts); + partitioned_tot_freq_filter_list=concatenateVectors(freq_filter_list_parts); + +/* for (int partid = 0; partid < kSavedPartitionNum; partid++) { std::vector& key_list = key_list_parts[partid]; std::vector& valueptr_list = valueptr_list_parts[partid]; @@ -306,6 +342,7 @@ Status DumpEmbeddingValues(EmbeddingVar* ev, part_offset_flat(partid + 1) = part_offset_flat(partid) + key_list.size(); part_filter_offset[partid + 1] = part_filter_offset[partid] + key_filter_list.size(); } +*/ // TODO: DB iterator not support partition_offset writer->Add(tensor_key + "-partition_offset", *part_offset_tensor); for(int i = 0; i < kSavedPartitionNum + 1; i++) {