From f60283eeee62359145486316166667e926b67eaf Mon Sep 17 00:00:00 2001 From: huangzhaowei Date: Thu, 14 Nov 2024 05:48:36 +0800 Subject: [PATCH] refactor: remove the queue in LanceArrowWriter to reduce memory usage for spark sink (#3110) Remove the queue in LanceArrowWriter since it may cache all rows in the queue, which requires a lot of JVM memory. Use semaphores to control the write rate of the sink: the writer waits until the reader takes the batch. In addition, I moved the `maven-shade-plugin` into a new profile, which is disabled by default, because the `jar-with-dependencies` artifact conflicts with many jars in the Spark dependencies. --------- Co-authored-by: Lei Xu --- .github/workflows/java-publish.yml | 2 +- java/spark/pom.xml | 79 ++++++++++--------- .../lance/spark/write/LanceArrowWriter.java | 74 +++++++++-------- 3 files changed, 84 insertions(+), 71 deletions(-) diff --git a/.github/workflows/java-publish.yml b/.github/workflows/java-publish.yml index 9d79911f54..47a4213e13 100644 --- a/.github/workflows/java-publish.yml +++ b/.github/workflows/java-publish.yml @@ -111,7 +111,7 @@ jobs: echo "use-agent" >> ~/.gnupg/gpg.conf echo "pinentry-mode loopback" >> ~/.gnupg/gpg.conf export GPG_TTY=$(tty) - mvn --batch-mode -DskipTests -Drust.release.build=true -DpushChanges=false -Dgpg.passphrase=${{ secrets.GPG_PASSPHRASE }} deploy -P deploy-to-ossrh + mvn --batch-mode -DskipTests -Drust.release.build=true -DpushChanges=false -Dgpg.passphrase=${{ secrets.GPG_PASSPHRASE }} deploy -P deploy-to-ossrh -P shade-jar env: SONATYPE_USER: ${{ secrets.SONATYPE_USER }} SONATYPE_TOKEN: ${{ secrets.SONATYPE_TOKEN }} diff --git a/java/spark/pom.xml b/java/spark/pom.xml index b4c1d73c0a..1a6685c307 100644 --- a/java/spark/pom.xml +++ b/java/spark/pom.xml @@ -34,6 +34,48 @@ 2.13 + + shade-jar + + false + + + + + org.apache.maven.plugins + maven-shade-plugin + + + uber-jar + + shade + + package + + ${project.artifactId}-${scala.compat.version}-${project.version}-jar-with-dependencies + + + + + + + + 
*:* + + LICENSE + META-INF/*.SF + META-INF/*.DSA + META-INF/*.RSA + + + + + + + + + + @@ -53,41 +95,4 @@ test - - - - - org.apache.maven.plugins - maven-shade-plugin - - - uber-jar - - shade - - package - - ${project.artifactId}-${scala.compat.version}-${project.version}-jar-with-dependencies - - - - - - - - *:* - - LICENSE - META-INF/*.SF - META-INF/*.DSA - META-INF/*.RSA - - - - - - - - - diff --git a/java/spark/src/main/java/com/lancedb/lance/spark/write/LanceArrowWriter.java b/java/spark/src/main/java/com/lancedb/lance/spark/write/LanceArrowWriter.java index 341aa11ab3..f7fa2e6f45 100644 --- a/java/spark/src/main/java/com/lancedb/lance/spark/write/LanceArrowWriter.java +++ b/java/spark/src/main/java/com/lancedb/lance/spark/write/LanceArrowWriter.java @@ -25,6 +25,8 @@ import java.io.IOException; import java.util.Queue; import java.util.concurrent.ConcurrentLinkedQueue; +import java.util.concurrent.Semaphore; +import java.util.concurrent.atomic.AtomicInteger; import java.util.concurrent.atomic.AtomicLong; /** @@ -41,6 +43,9 @@ public class LanceArrowWriter extends ArrowReader { private final AtomicLong totalBytesRead = new AtomicLong(); private ArrowWriter arrowWriter = null; + private final AtomicInteger count = new AtomicInteger(0); + private final Semaphore writeToken; + private final Semaphore loadToken; public LanceArrowWriter(BufferAllocator allocator, Schema schema, int batchSize) { super(allocator); @@ -49,60 +54,63 @@ public LanceArrowWriter(BufferAllocator allocator, Schema schema, int batchSize) this.schema = schema; // TODO(lu) batch size as config? 
this.batchSize = batchSize; + this.writeToken = new Semaphore(0); + this.loadToken = new Semaphore(0); } void write(InternalRow row) { Preconditions.checkNotNull(row); - synchronized (monitor) { - // TODO(lu) wait if too much elements in rowQueue - rowQueue.offer(row); - monitor.notify(); + try { + // wait util prepareLoadNextBatch to release write token, + writeToken.acquire(); + arrowWriter.write(row); + if (count.incrementAndGet() == batchSize) { + // notify loadNextBatch to take the batch + loadToken.release(); + } + } catch (InterruptedException e) { + throw new RuntimeException(e); } } void setFinished() { - synchronized (monitor) { - finished = true; - monitor.notify(); - } + loadToken.release(); + finished = true; } @Override - protected void prepareLoadNextBatch() throws IOException { + public void prepareLoadNextBatch() throws IOException { super.prepareLoadNextBatch(); - // Do not use ArrowWriter.reset since it does not work well with Arrow JNI arrowWriter = ArrowWriter.create(this.getVectorSchemaRoot()); + // release batch size token for write + writeToken.release(batchSize); } @Override public boolean loadNextBatch() throws IOException { prepareLoadNextBatch(); - int rowCount = 0; - synchronized (monitor) { - while (rowCount < batchSize) { - while (rowQueue.isEmpty() && !finished) { - try { - monitor.wait(); - } catch (InterruptedException e) { - Thread.currentThread().interrupt(); - throw new IOException("Interrupted while waiting for data", e); - } - } - if (rowQueue.isEmpty() && finished) { - break; - } - InternalRow row = rowQueue.poll(); - if (row != null) { - arrowWriter.write(row); - rowCount++; + try { + if (finished && count.get() == 0) { + return false; + } + // wait util batch if full or finished + loadToken.acquire(); + arrowWriter.finish(); + if (!finished) { + count.set(0); + return true; + } else { + // true if it has some rows and return false if there is no record + if (count.get() > 0) { + count.set(0); + return true; + } else { + 
return false; } } + } catch (InterruptedException e) { + throw new RuntimeException(e); } - if (rowCount == 0) { - return false; - } - arrowWriter.finish(); - return true; } @Override