Skip to content

Commit

Permalink
修改张量的填充方法
Browse files Browse the repository at this point in the history
  • Loading branch information
zjhellofss committed Feb 8, 2024
1 parent 795c2b5 commit 930767d
Show file tree
Hide file tree
Showing 4 changed files with 11 additions and 10 deletions.
8 changes: 5 additions & 3 deletions include/data/tensor_util.hpp
Original file line number Diff line number Diff line change
Expand Up @@ -335,15 +335,17 @@ std::tuple<std::shared_ptr<Tensor<T>>, std::shared_ptr<Tensor<T>>> TensorBroadca
TensorCreate<T>(tensor2->channels(), tensor1->rows(), tensor1->cols());
CHECK(tensor2->size() == tensor2->channels());
for (uint32_t c = 0; c < tensor2->channels(); ++c) {
new_tensor->slice(c).fill(tensor2->index(c));
T* new_tensor_ptr = new_tensor->matrix_raw_ptr(c);
std::fill(new_tensor_ptr, new_tensor_ptr + new_tensor->plane_size(), tensor2->index(c));
}
return {tensor1, new_tensor};
} else if (tensor1->rows() == 1 && tensor1->cols() == 1) {
std::shared_ptr<Tensor<T>> new_tensor =
TensorCreate<T>(tensor1->channels(), tensor2->rows(), tensor2->cols());
CHECK(tensor1->size() == tensor1->channels());
for (uint32_t c = 0; c < tensor1->channels(); ++c) {
new_tensor->slice(c).fill(tensor1->index(c));
for (uint32_t c = 0; c < tensor1->channels(); ++c) {
T* new_tensor_ptr = new_tensor->matrix_raw_ptr(c);
std::fill(new_tensor_ptr, new_tensor_ptr + new_tensor->plane_size(), tensor1->index(c));
}
return {new_tensor, tensor2};
} else {
Expand Down
2 changes: 1 addition & 1 deletion include/runtime/runtime_ir.hpp
Original file line number Diff line number Diff line change
Expand Up @@ -51,7 +51,7 @@ class RuntimeGraph {
* @param param_path Path to the parameter file defining the graph structure
* @param bin_path Path to the bin file containing the graph weights
*/
RuntimeGraph(std::string param_path, std::string bin_path);
explicit RuntimeGraph(std::string param_path, std::string bin_path);

/**
* @brief Sets the inputs to the graph
Expand Down
6 changes: 3 additions & 3 deletions include/runtime/runtime_operand.hpp
Original file line number Diff line number Diff line change
Expand Up @@ -41,10 +41,10 @@ namespace kuiper_infer {
*/
template <typename T>
struct RuntimeOperandBase {
RuntimeOperandBase() = default;
explicit RuntimeOperandBase() = default;

RuntimeOperandBase(std::string name, std::vector<int32_t> shapes,
std::vector<std::shared_ptr<Tensor<T>>> datas, RuntimeDataType type)
explicit RuntimeOperandBase(std::string name, std::vector<int32_t> shapes,
std::vector<std::shared_ptr<Tensor<T>>> datas, RuntimeDataType type)
: name(std::move(name)), shapes(std::move(shapes)), datas(std::move(datas)), type(type) {}

/// Name of the operand
Expand Down
5 changes: 2 additions & 3 deletions source/runtime/runtime_ir.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -114,7 +114,6 @@ void RuntimeGraph::Build() {
// 初始化节点的输入和输出空间
RuntimeOperatorUtils<float>::InitOperatorInput(operators_);
RuntimeOperatorUtils<float>::InitOperatorOutput(graph_->ops, operators_);

graph_state_ = GraphState::Complete;
if (graph_ != nullptr) {
graph_.reset();
Expand All @@ -123,8 +122,8 @@ void RuntimeGraph::Build() {
}

template <typename T>
StatusCode ExecuteLayer(const std::shared_ptr<Layer<T>>& layer, const std::string& op_name,
const std::string& op_type, bool is_debug) {
StatusCode ExecuteLayer(const T& layer, const std::string& op_name, const std::string& op_type,
bool is_debug) {
CHECK(layer != nullptr);
StatusCode status;
if (is_debug) {
Expand Down

0 comments on commit 930767d

Please sign in to comment.