diff --git a/spark/src/main/scala/org/apache/spark/sql/delta/storage/ClosableIterator.scala b/spark/src/main/scala/org/apache/spark/sql/delta/storage/ClosableIterator.scala
index 618972f6489..c92840d544b 100644
--- a/spark/src/main/scala/org/apache/spark/sql/delta/storage/ClosableIterator.scala
+++ b/spark/src/main/scala/org/apache/spark/sql/delta/storage/ClosableIterator.scala
@@ -63,12 +63,7 @@ object ClosableIterator {
   implicit class IteratorFlatMapCloseOp[A](val closableIter: Iterator[A]) extends AnyVal {
     def flatMapWithClose[B](f: A => ClosableIterator[B]): ClosableIterator[B] =
       new ClosableIterator[B] {
-        private var iter_curr =
-          if (closableIter.hasNext) {
-            f(closableIter.next())
-          } else {
-            null
-          }
+        private var iter_curr: ClosableIterator[B] = null // set on demand inside hasNext
         override def next(): B = {
           if (!hasNext) {
             throw new NoSuchElementException
@@ -77,6 +72,9 @@ object ClosableIterator {
         }
         @scala.annotation.tailrec
         override def hasNext: Boolean = {
+          if (iter_curr == null && closableIter.hasNext) { // no inner iterator open yet
+            iter_curr = f(closableIter.next()) // f runs here, not at construction
+          }
           if (iter_curr == null) {
             false
           }
diff --git a/spark/src/test/scala/org/apache/spark/sql/delta/storage/LineClosableIteratorSuite.scala b/spark/src/test/scala/org/apache/spark/sql/delta/storage/LineClosableIteratorSuite.scala
index de9a9b1af67..2e7b08309cc 100644
--- a/spark/src/test/scala/org/apache/spark/sql/delta/storage/LineClosableIteratorSuite.scala
+++ b/spark/src/test/scala/org/apache/spark/sql/delta/storage/LineClosableIteratorSuite.scala
@@ -122,6 +122,64 @@ abstract class LineClosableIteratorSuiteBase extends SparkFunSuite {
     iter.close()
     assert(closed == 1)
   }
+
+  test("flatMapWithClose does not open any iterators on creation") {
+    var opened = 0 // inner readers constructed so far
+    var closed = 0 // close() calls seen on inner readers
+    val outerReader = new StringReader("b\na\nr")
+    createIter(outerReader).flatMapWithClose(_ => {
+      val innerReader = new StringReader("f\no\no") {
+        opened += 1 // subclass initializer: runs once per reader built
+        override def close(): Unit = {
+          super.close()
+          closed += 1
+        }
+      }
+      createIter(innerReader)
+    })
+    assert(opened == 0) // the mapping lambda has not been invoked
+    assert(closed == 0)
+  }
+
+  test("flatMapWithClose calls close only for opened iterators") {
+    var opened = 0 // inner readers constructed so far
+    var closed = 0 // close() calls seen on inner readers
+    val outerReader = new StringReader("b\na\nr")
+    val iter = createIter(outerReader).flatMapWithClose(_ => {
+      val innerReader = new StringReader("f\no\no") {
+        opened += 1 // subclass initializer: runs once per reader built
+        override def close(): Unit = {
+          super.close()
+          closed += 1
+        }
+      }
+      createIter(innerReader)
+    })
+    assert(iter.take(5).toList == List("f", "o", "o", "f", "o")) // stops mid-way in 2nd reader
+    iter.close()
+    assert(opened == 2) // only two inner readers were ever built
+    assert(closed == 2) // and both of them were closed
+  }
+
+  test("flatMapWithClose calls close only for opened iterators - iter boundary") {
+    var opened = 0 // inner readers constructed so far
+    var closed = 0 // close() calls seen on inner readers
+    val outerReader = new StringReader("b\na\nr")
+    val iter = createIter(outerReader).flatMapWithClose(_ => {
+      val innerReader = new StringReader("f\no\no") {
+        opened += 1 // subclass initializer: runs once per reader built
+        override def close(): Unit = {
+          super.close()
+          closed += 1
+        }
+      }
+      createIter(innerReader)
+    })
+    assert(iter.take(3).toList == List("f", "o", "o")) // exactly the first reader's lines
+    iter.close()
+    assert(opened == 1) // ending on the boundary must not build a second reader
+    assert(closed == 1)
+  }
 }
 
 class InternalLineClosableIteratorSuite extends LineClosableIteratorSuiteBase {