Skip to content

Commit

Permalink
Fix memory leak from reference cycle
Browse files Browse the repository at this point in the history
This commit is taken from an umerged PR in taco
(tensor-compiler#520).
  • Loading branch information
Ectras committed Feb 1, 2023
1 parent de40a77 commit c9d3f87
Showing 1 changed file with 17 additions and 6 deletions.
23 changes: 17 additions & 6 deletions src/tensor.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -476,10 +476,10 @@ static inline map<TensorVar, TensorBase> getTensors(const IndexExpr& expr);
struct AccessTensorNode : public AccessNode {
AccessTensorNode(TensorBase tensor, const std::vector<IndexVar>& indices)
: AccessNode(tensor.getTensorVar(), indices, {}, false),
tensor(tensor) {}
tensorPtr(tensor.content) {}

AccessTensorNode(TensorBase tensor, const std::vector<std::shared_ptr<IndexVarInterface>>& indices)
: AccessNode(tensor.getTensorVar()), tensor(tensor) {
: AccessNode(tensor.getTensorVar()), tensorPtr(tensor.content) {
// Create the vector of IndexVar to assign to this->indexVars.
std::vector<IndexVar> ivars(indices.size());
for (size_t i = 0; i < indices.size(); i++) {
Expand Down Expand Up @@ -517,10 +517,21 @@ struct AccessTensorNode : public AccessNode {
this->indexVars = std::move(ivars);
}

TensorBase tensor;
// We hold a weak_ptr to the accessed TensorBase to avoid creating a reference
// cycle between the accessed TensorBase and this AccessTensorNode, since the
// TensorBase will store the AccessTensorNode (as part of an IndexExpr) as a
// field on itself. Not using a weak pointer results in leaking TensorBases.
std::weak_ptr<TensorBase::Content> tensorPtr;
TensorBase getTensor() const {
TensorBase tensor;
tensor.content = tensorPtr.lock();
return tensor;
}

virtual void setAssignment(const Assignment& assignment) {
tensor.syncDependentTensors();
auto tensor = this->getTensor();

tensor.syncDependentTensors();
Assignment assign = makeReductionNotation(assignment);

tensor.setNeedsPack(false);
Expand Down Expand Up @@ -751,7 +762,7 @@ static inline map<TensorVar, TensorBase> getTensors(const IndexExpr& expr) {
taco_iassert(isa<AccessTensorNode>(node)) << "Unknown subexpression";

if (!util::contains(arguments, node->tensorVar)) {
arguments.insert({node->tensorVar, to<AccessTensorNode>(node)->tensor});
arguments.insert({node->tensorVar, to<AccessTensorNode>(node)->getTensor()});
}

// Also add any tensors backing index sets of tensor accesses.
Expand All @@ -763,7 +774,7 @@ static inline map<TensorVar, TensorBase> getTensors(const IndexExpr& expr) {
}

// TODO (rohany): This seems like dead code.
TensorBase tensor = to<AccessTensorNode>(node)->tensor;
TensorBase tensor = to<AccessTensorNode>(node)->getTensor();
if (!util::contains(inserted, tensor)) {
inserted.insert(tensor);
operands.push_back(tensor);
Expand Down

0 comments on commit c9d3f87

Please sign in to comment.