diff --git a/src/tensor.cpp b/src/tensor.cpp index 257c396c3..9588bf224 100644 --- a/src/tensor.cpp +++ b/src/tensor.cpp @@ -476,10 +476,10 @@ static inline map getTensors(const IndexExpr& expr); struct AccessTensorNode : public AccessNode { AccessTensorNode(TensorBase tensor, const std::vector& indices) : AccessNode(tensor.getTensorVar(), indices, {}, false), - tensor(tensor) {} + tensorPtr(tensor.content) {} AccessTensorNode(TensorBase tensor, const std::vector>& indices) - : AccessNode(tensor.getTensorVar()), tensor(tensor) { + : AccessNode(tensor.getTensorVar()), tensorPtr(tensor.content) { // Create the vector of IndexVar to assign to this->indexVars. std::vector ivars(indices.size()); for (size_t i = 0; i < indices.size(); i++) { @@ -517,10 +517,21 @@ struct AccessTensorNode : public AccessNode { this->indexVars = std::move(ivars); } - TensorBase tensor; + // We hold a weak_ptr to the accessed TensorBase to avoid creating a reference + // cycle between the accessed TensorBase and this AccessTensorNode, since the + // TensorBase will store the AccessTensorNode (as part of an IndexExpr) as a + // field on itself. Not using a weak pointer results in leaking TensorBases. + std::weak_ptr tensorPtr; + TensorBase getTensor() const { + TensorBase tensor; + tensor.content = tensorPtr.lock(); + return tensor; + } + virtual void setAssignment(const Assignment& assignment) { - tensor.syncDependentTensors(); + auto tensor = this->getTensor(); + tensor.syncDependentTensors(); Assignment assign = makeReductionNotation(assignment); tensor.setNeedsPack(false); @@ -751,7 +762,7 @@ static inline map getTensors(const IndexExpr& expr) { taco_iassert(isa(node)) << "Unknown subexpression"; if (!util::contains(arguments, node->tensorVar)) { - arguments.insert({node->tensorVar, to(node)->tensor}); + arguments.insert({node->tensorVar, to(node)->getTensor()}); } // Also add any tensors backing index sets of tensor accesses. @@ -763,7 +774,7 @@ static inline map getTensors(const IndexExpr& expr) { } // TODO (rohany): This seems like dead code. - TensorBase tensor = to(node)->tensor; + TensorBase tensor = to(node)->getTensor(); if (!util::contains(inserted, tensor)) { inserted.insert(tensor); operands.push_back(tensor);