Skip to content

Commit

Permalink
enable no bias in FC layout inference (VeriSilicon#294)
Browse files Browse the repository at this point in the history
Signed-off-by: yuenan.li <[email protected]>

Co-authored-by: yuenan.li <[email protected]>
  • Loading branch information
liyuenan2333 and yuenan.li authored Feb 21, 2022
1 parent 6e0ac09 commit fe31a47
Showing 1 changed file with 5 additions and 9 deletions.
14 changes: 5 additions & 9 deletions src/tim/transform/ops/fullyconnected_layout_inference.h
Original file line number Diff line number Diff line change
Expand Up @@ -53,19 +53,15 @@ class FullyConnectedLayoutInfer : public OpLayoutInfer {
context_->SetPermuteVector(in, trans_pv);
}
}
uint32_t axis = op_->impl()->node()->nn_param.fcl.axis;
uint32_t weight = op_->impl()->node()->nn_param.fcl.weights;

auto fcl = context_->infer_graph_->CreateOperation<vx::ops::FullyConnected>(
axis, weight);
auto fcl = op_->Clone(context_->infer_graph_);
auto required_pv =
MakeShared(op_->impl()->OutputsTensor()[0]->GetShape().size());
auto out_infer = CreateOutputsTensor(required_pv);
(*fcl)
.BindInputs({context_->GetMapedTensor(op_->impl()->InputsTensor()[0]),
context_->GetMapedTensor(op_->impl()->InputsTensor()[1]),
context_->GetMapedTensor(op_->impl()->InputsTensor()[2])})
.BindOutput(out_infer[0]);
for (auto in : op_->impl()->InputsTensor()) {
(*fcl).BindInput(context_->GetMapedTensor(in));
}
(*fcl).BindOutput(out_infer[0]);
context_->SetPermuteVector(op_->impl()->OutputsTensor()[0], required_pv);
next_tensors.push_back(op_->impl()->OutputsTensor()[0]);
}
Expand Down

0 comments on commit fe31a47

Please sign in to comment.