From 6bfb8ffb2fd36c009aacb6ea4059e44b5d9f0902 Mon Sep 17 00:00:00 2001
From: Tom Stepp <tom.j.stepp@gmail.com>
Date: Wed, 15 Jan 2025 00:23:24 +0000
Subject: [PATCH] Add offset-based deduplication support to the Kafka unbounded source.

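Adds opt-in offset-based deduplication to the legacy KafkaIO unbounded source.
When enabled, split() produces exactly one partition per split, each record
carries a unique ID of the form "<topic>-<partition>-<offset>", and the
checkpoint mark exposes the next offset as the deduplication limit so the
runner can discard replayed records.

Illustrative usage (a sketch only; the bootstrap servers, topic, and
deserializers below are placeholders, not part of this change):

    KafkaIO.<byte[], byte[]>read()
        .withBootstrapServers("broker:9092")
        .withTopic("my-topic")
        .withKeyDeserializer(ByteArrayDeserializer.class)
        .withValueDeserializer(ByteArrayDeserializer.class)
        .withRedistribute()
        .withAllowDuplicates(false)
        .withRedistributeNumKeys(100)
        .withOffsetDeduplication(true);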
---
 .../org/apache/beam/sdk/coders/Coder.java     |  3 +-
 .../sdk/io/kafka/KafkaCheckpointMark.java     | 21 +++++++
 .../org/apache/beam/sdk/io/kafka/KafkaIO.java | 55 ++++++++++++++---
 ...afkaIOReadImplementationCompatibility.java |  7 ++-
 .../beam/sdk/io/kafka/KafkaIOUtils.java       | 13 ++++
 .../sdk/io/kafka/KafkaUnboundedReader.java    | 40 +++++++++++-
 .../sdk/io/kafka/KafkaUnboundedSource.java    | 23 +++++--
 .../sdk/io/kafka/KafkaIOExternalTest.java     | 11 +++-
 ...IOReadImplementationCompatibilityTest.java | 15 ++++-
 .../apache/beam/sdk/io/kafka/KafkaIOTest.java | 61 ++++++++++++++-----
 .../io/kafka/upgrade/KafkaIOTranslation.java  |  6 ++
 11 files changed, 219 insertions(+), 36 deletions(-)
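
Note on the encoding: the record IDs and offset limits introduced below are
plain byte arrays. A minimal standalone sketch of the format (it mirrors the
new KafkaIOUtils.OffsetBasedDeduplication helpers; the class name and main()
are illustrative only, not part of the patch):

    import java.nio.ByteBuffer;
    import java.nio.charset.StandardCharsets;
    import java.util.Arrays;

    public class OffsetDedupEncodingSketch {
      // Unique record ID: "<topic>-<partition>-<offset>" encoded as UTF-8 bytes.
      static byte[] uniqueId(String topic, int partition, long offset) {
        return (topic + "-" + partition + "-" + offset).getBytes(StandardCharsets.UTF_8);
      }

      // Offset limit: the offset as 8 big-endian bytes, equivalent to
      // Guava's Longs.toByteArray(offset) used in the patch.
      static byte[] encodeOffset(long offset) {
        return ByteBuffer.allocate(Long.BYTES).putLong(offset).array();
      }

      public static void main(String[] args) {
        System.out.println(new String(uniqueId("my-topic", 3, 42L), StandardCharsets.UTF_8));
        System.out.println(Arrays.toString(encodeOffset(42L)));
      }
    }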

diff --git a/sdks/java/core/src/main/java/org/apache/beam/sdk/coders/Coder.java b/sdks/java/core/src/main/java/org/apache/beam/sdk/coders/Coder.java
index 6bcdea0c0ab6..0a3650ca133b 100644
--- a/sdks/java/core/src/main/java/org/apache/beam/sdk/coders/Coder.java
+++ b/sdks/java/core/src/main/java/org/apache/beam/sdk/coders/Coder.java
@@ -198,7 +198,8 @@ public static void verifyDeterministic(Coder<?> target, String message, Iterable
     }
   }
 
-  public static <T> long getEncodedElementByteSizeUsingCoder(Coder<T> target, T value) throws Exception {
+  public static <T> long getEncodedElementByteSizeUsingCoder(Coder<T> target, T value)
+      throws Exception {
     return target.getEncodedElementByteSize(value);
   }
   /**
diff --git a/sdks/java/io/kafka/src/main/java/org/apache/beam/sdk/io/kafka/KafkaCheckpointMark.java b/sdks/java/io/kafka/src/main/java/org/apache/beam/sdk/io/kafka/KafkaCheckpointMark.java
index 966363e41f62..4271d6f72a03 100644
--- a/sdks/java/io/kafka/src/main/java/org/apache/beam/sdk/io/kafka/KafkaCheckpointMark.java
+++ b/sdks/java/io/kafka/src/main/java/org/apache/beam/sdk/io/kafka/KafkaCheckpointMark.java
@@ -17,6 +17,8 @@
  */
 package org.apache.beam.sdk.io.kafka;
 
+import static org.apache.beam.vendor.guava.v32_1_2_jre.com.google.common.base.Preconditions.checkState;
+
 import java.io.Serializable;
 import java.util.List;
 import java.util.Optional;
@@ -42,6 +44,8 @@ public class KafkaCheckpointMark implements UnboundedSource.CheckpointMark {
   @SuppressWarnings("initialization") // Avro will set the fields by breaking abstraction
   private KafkaCheckpointMark() {} // for Avro
 
+  private static final long OFFSET_DEDUP_PARTITIONS_PER_SPLIT = 1;
+
   public KafkaCheckpointMark(
       List<PartitionMark> partitions, Optional<KafkaUnboundedReader<?, ?>> reader) {
     this.partitions = partitions;
@@ -66,6 +70,23 @@ public String toString() {
     return "KafkaCheckpointMark{partitions=" + Joiner.on(",").join(partitions) + '}';
   }
 
+  @Override
+  public byte[] getOffsetLimit() {
+    if (!reader.isPresent()) {
+      throw new RuntimeException(
+          "getOffsetLimit() was called, but the KafkaCheckpointMark has no reader.");
+    }
+    if (!reader.get().offsetBasedDeduplicationSupported()) {
+      throw new RuntimeException(
+          "getOffsetLimit() requires KafkaUnboundedReader to be configured for offset-based deduplication.");
+    }
+
+    // KafkaUnboundedSource.split() must produce a 1:1 partition to split ratio.
+    checkState(partitions.size() == OFFSET_DEDUP_PARTITIONS_PER_SPLIT);
+    PartitionMark partition = partitions.get(/* index= */ 0);
+    return KafkaIOUtils.OffsetBasedDeduplication.encodeOffset(partition.getNextOffset());
+  }
+
   /**
    * A tuple to hold topic, partition, and offset that comprise the checkpoint for a single
    * partition.
diff --git a/sdks/java/io/kafka/src/main/java/org/apache/beam/sdk/io/kafka/KafkaIO.java b/sdks/java/io/kafka/src/main/java/org/apache/beam/sdk/io/kafka/KafkaIO.java
index cb7b3020c66a..27676c280275 100644
--- a/sdks/java/io/kafka/src/main/java/org/apache/beam/sdk/io/kafka/KafkaIO.java
+++ b/sdks/java/io/kafka/src/main/java/org/apache/beam/sdk/io/kafka/KafkaIO.java
@@ -610,6 +610,7 @@ public static <K, V> Read<K, V> read() {
         .setRedistributed(false)
         .setAllowDuplicates(false)
         .setRedistributeNumKeys(0)
+        .setOffsetDeduplication(false)
         .build();
   }
 
@@ -717,6 +718,9 @@ public abstract static class Read<K, V>
     @Pure
     public abstract int getRedistributeNumKeys();
 
+    @Pure
+    public abstract boolean isOffsetDeduplication();
+
     @Pure
     public abstract @Nullable Duration getWatchTopicPartitionDuration();
 
@@ -782,6 +786,8 @@ abstract Builder<K, V> setConsumerFactoryFn(
 
       abstract Builder<K, V> setRedistributeNumKeys(int redistributeNumKeys);
 
+      abstract Builder<K, V> setOffsetDeduplication(boolean offsetDeduplication);
+
       abstract Builder<K, V> setTimestampPolicyFactory(
           TimestampPolicyFactory<K, V> timestampPolicyFactory);
 
@@ -886,11 +892,16 @@ static <K, V> void setupExternalBuilder(
           if (config.allowDuplicates != null) {
             builder.setAllowDuplicates(config.allowDuplicates);
           }
-
+          if (config.redistribute
+              && (config.allowDuplicates == null || !config.allowDuplicates)
+              && config.offsetDeduplication != null) {
+            builder.setOffsetDeduplication(config.offsetDeduplication);
+          }
         } else {
           builder.setRedistributed(false);
           builder.setRedistributeNumKeys(0);
           builder.setAllowDuplicates(false);
+          builder.setOffsetDeduplication(false);
         }
       }
 
@@ -959,6 +970,7 @@ public static class Configuration {
         private Integer redistributeNumKeys;
         private Boolean redistribute;
         private Boolean allowDuplicates;
+        private Boolean offsetDeduplication;
 
         public void setConsumerConfig(Map<String, String> consumerConfig) {
           this.consumerConfig = consumerConfig;
@@ -1015,6 +1027,10 @@ public void setRedistribute(Boolean redistribute) {
         public void setAllowDuplicates(Boolean allowDuplicates) {
           this.allowDuplicates = allowDuplicates;
         }
+
+        public void setOffsetDeduplication(Boolean offsetDeduplication) {
+          this.offsetDeduplication = offsetDeduplication;
+        }
       }
     }
 
@@ -1066,26 +1082,21 @@ public Read<K, V> withTopicPartitions(List<TopicPartition> topicPartitions) {
      * Sets redistribute transform that hints to the runner to try to redistribute the work evenly.
      */
     public Read<K, V> withRedistribute() {
-      if (getRedistributeNumKeys() == 0 && isRedistributed()) {
-        LOG.warn("This will create a key per record, which is sub-optimal for most use cases.");
-      }
       return toBuilder().setRedistributed(true).build();
     }
 
     public Read<K, V> withAllowDuplicates(Boolean allowDuplicates) {
-      if (!isAllowDuplicates()) {
-        LOG.warn("Setting this value without setting withRedistribute() will have no effect.");
-      }
       return toBuilder().setAllowDuplicates(allowDuplicates).build();
     }
 
     public Read<K, V> withRedistributeNumKeys(int redistributeNumKeys) {
-      checkState(
-          isRedistributed(),
-          "withRedistributeNumKeys is ignored if withRedistribute() is not enabled on the transform.");
       return toBuilder().setRedistributeNumKeys(redistributeNumKeys).build();
     }
 
+    public Read<K, V> withOffsetDeduplication(boolean offsetDeduplication) {
+      return toBuilder().setOffsetDeduplication(offsetDeduplication).build();
+    }
+
     /**
      * Internally sets a {@link java.util.regex.Pattern} of topics to read from. All the partitions
      * from each of the matching topics are read.
@@ -1541,6 +1552,9 @@ public PCollection<KafkaRecord<K, V>> expand(PBegin input) {
               ConsumerConfig.ENABLE_AUTO_COMMIT_CONFIG);
         }
       }
+
+      checkRedistributeConfiguration();
+
       warnAboutUnsafeConfigurations(input);
 
       // Infer key/value coders if not specified explicitly
@@ -1573,6 +1587,27 @@ && runnerPrefersLegacyRead(input.getPipeline().getOptions()))) {
       return input.apply(new ReadFromKafkaViaSDF<>(this, keyCoder, valueCoder));
     }
 
+    private void checkRedistributeConfiguration() {
+      if (getRedistributeNumKeys() == 0 && isRedistributed()) {
+        LOG.warn(
+            "withRedistribute without withRedistributeNumKeys will create a key per record, which is sub-optimal for most use cases.");
+      }
+      if (isAllowDuplicates()) {
+        checkState(
+            isRedistributed(), "withAllowDuplicates without withRedistribute will have no effect.");
+      }
+      if (getRedistributeNumKeys() > 0) {
+        checkState(
+            isRedistributed(),
+            "withRedistributeNumKeys is ignored if withRedistribute() is not enabled on the transform.");
+      }
+      if (isOffsetDeduplication()) {
+        checkState(
+            isRedistributed() && !isAllowDuplicates(),
+            "withOffsetDeduplication should only be used with withRedistribute and withAllowDuplicates(false).");
+      }
+    }
+
     private void warnAboutUnsafeConfigurations(PBegin input) {
       Long checkpointingInterval =
           input
diff --git a/sdks/java/io/kafka/src/main/java/org/apache/beam/sdk/io/kafka/KafkaIOReadImplementationCompatibility.java b/sdks/java/io/kafka/src/main/java/org/apache/beam/sdk/io/kafka/KafkaIOReadImplementationCompatibility.java
index 457e0003705e..6702380c0897 100644
--- a/sdks/java/io/kafka/src/main/java/org/apache/beam/sdk/io/kafka/KafkaIOReadImplementationCompatibility.java
+++ b/sdks/java/io/kafka/src/main/java/org/apache/beam/sdk/io/kafka/KafkaIOReadImplementationCompatibility.java
@@ -137,7 +137,12 @@ Object getDefaultValue() {
         return false;
       }
     },
-    ;
+    OFFSET_DEDUPLICATION(LEGACY) {
+      @Override
+      Object getDefaultValue() {
+        return false;
+      }
+    };
 
     private final @NonNull ImmutableSet<KafkaIOReadImplementation> supportedImplementations;
 
diff --git a/sdks/java/io/kafka/src/main/java/org/apache/beam/sdk/io/kafka/KafkaIOUtils.java b/sdks/java/io/kafka/src/main/java/org/apache/beam/sdk/io/kafka/KafkaIOUtils.java
index 748418d16664..95f95000a58f 100644
--- a/sdks/java/io/kafka/src/main/java/org/apache/beam/sdk/io/kafka/KafkaIOUtils.java
+++ b/sdks/java/io/kafka/src/main/java/org/apache/beam/sdk/io/kafka/KafkaIOUtils.java
@@ -19,11 +19,13 @@
 
 import static org.apache.beam.vendor.guava.v32_1_2_jre.com.google.common.base.Preconditions.checkArgument;
 
+import java.nio.charset.StandardCharsets;
 import java.util.HashMap;
 import java.util.Map;
 import java.util.Random;
 import org.apache.beam.sdk.transforms.SerializableFunction;
 import org.apache.beam.vendor.guava.v32_1_2_jre.com.google.common.collect.ImmutableMap;
+import org.apache.beam.vendor.guava.v32_1_2_jre.com.google.common.primitives.Longs;
 import org.apache.kafka.clients.consumer.Consumer;
 import org.apache.kafka.clients.consumer.ConsumerConfig;
 import org.apache.kafka.clients.consumer.KafkaConsumer;
@@ -142,4 +144,15 @@ void update(double quantity) {
       return avg;
     }
   }
+
+  static final class OffsetBasedDeduplication {
+
+    static byte[] encodeOffset(long offset) {
+      return Longs.toByteArray(offset);
+    }
+
+    static byte[] getUniqueId(String topic, int partition, long offset) {
+      return (topic + "-" + partition + "-" + offset).getBytes(StandardCharsets.UTF_8);
+    }
+  }
 }
diff --git a/sdks/java/io/kafka/src/main/java/org/apache/beam/sdk/io/kafka/KafkaUnboundedReader.java b/sdks/java/io/kafka/src/main/java/org/apache/beam/sdk/io/kafka/KafkaUnboundedReader.java
index ab9e26b3b740..9c82821b7d1a 100644
--- a/sdks/java/io/kafka/src/main/java/org/apache/beam/sdk/io/kafka/KafkaUnboundedReader.java
+++ b/sdks/java/io/kafka/src/main/java/org/apache/beam/sdk/io/kafka/KafkaUnboundedReader.java
@@ -66,6 +66,7 @@
 import org.apache.kafka.common.TopicPartition;
 import org.apache.kafka.common.errors.WakeupException;
 import org.apache.kafka.common.serialization.Deserializer;
+import org.checkerframework.checker.nullness.qual.MonotonicNonNull;
 import org.checkerframework.checker.nullness.qual.Nullable;
 import org.joda.time.Duration;
 import org.joda.time.Instant;
@@ -299,6 +300,34 @@ public Instant getCurrentTimestamp() throws NoSuchElementException {
     return curTimestamp;
   }
 
+  private static final byte[] EMPTY_RECORD_ID = new byte[0];
+
+  @Override
+  public byte[] getCurrentRecordId() throws NoSuchElementException {
+    if (!this.offsetBasedDeduplicationSupported) {
+      // BoundedReadFromUnboundedSource and tests may call getCurrentRecordId() even when offset-based
+      // deduplication is not enabled, so the Kafka reader must not throw an exception in that case.
+      // Instead, an empty record ID is returned.
+      return EMPTY_RECORD_ID;
+    }
+    if (curRecord != null) {
+      return KafkaIOUtils.OffsetBasedDeduplication.getUniqueId(
+          curRecord.getTopic(), curRecord.getPartition(), curRecord.getOffset());
+    }
+    throw new NoSuchElementException("KafkaUnboundedReader's curRecord is null.");
+  }
+
+  @Override
+  public byte[] getCurrentRecordOffset() throws NoSuchElementException {
+    if (!this.offsetBasedDeduplicationSupported) {
+      throw new RuntimeException("getCurrentRecordOffset() requires offset-based deduplication.");
+    }
+    if (curRecord != null) {
+      return KafkaIOUtils.OffsetBasedDeduplication.encodeOffset(curRecord.getOffset());
+    }
+    throw new NoSuchElementException("KafkaUnboundedReader's curRecord is null.");
+  }
+
   @Override
   public long getSplitBacklogBytes() {
     long backlogBytes = 0;
@@ -314,6 +343,10 @@ public long getSplitBacklogBytes() {
     return backlogBytes;
   }
 
+  public boolean offsetBasedDeduplicationSupported() {
+    return this.offsetBasedDeduplicationSupported;
+  }
+
   ////////////////////////////////////////////////////////////////////////////////////////////////
 
   private static final Logger LOG = LoggerFactory.getLogger(KafkaUnboundedReader.class);
@@ -332,10 +365,12 @@ public long getSplitBacklogBytes() {
   private final String name;
   private @Nullable Consumer<byte[], byte[]> consumer = null;
   private final List<PartitionState<K, V>> partitionStates;
-  private @Nullable KafkaRecord<K, V> curRecord = null;
-  private @Nullable Instant curTimestamp = null;
+  private @MonotonicNonNull KafkaRecord<K, V> curRecord = null;
+  private @MonotonicNonNull Instant curTimestamp = null;
   private Iterator<PartitionState<K, V>> curBatch = Collections.emptyIterator();
 
+  private final boolean offsetBasedDeduplicationSupported;
+
   private @Nullable Deserializer<K> keyDeserializerInstance = null;
   private @Nullable Deserializer<V> valueDeserializerInstance = null;
 
@@ -507,6 +542,7 @@ Instant updateAndGetWatermark() {
       KafkaUnboundedSource<K, V> source, @Nullable KafkaCheckpointMark checkpointMark) {
     this.source = source;
     this.name = "Reader-" + source.getId();
+    this.offsetBasedDeduplicationSupported = source.offsetBasedDeduplicationSupported();
 
     List<TopicPartition> partitions =
         Preconditions.checkArgumentNotNull(source.getSpec().getTopicPartitions());
diff --git a/sdks/java/io/kafka/src/main/java/org/apache/beam/sdk/io/kafka/KafkaUnboundedSource.java b/sdks/java/io/kafka/src/main/java/org/apache/beam/sdk/io/kafka/KafkaUnboundedSource.java
index 9685d859b0a1..dc59de133cda 100644
--- a/sdks/java/io/kafka/src/main/java/org/apache/beam/sdk/io/kafka/KafkaUnboundedSource.java
+++ b/sdks/java/io/kafka/src/main/java/org/apache/beam/sdk/io/kafka/KafkaUnboundedSource.java
@@ -113,10 +113,20 @@ public List<KafkaUnboundedSource<K, V>> split(int desiredNumSplits, PipelineOpti
         partitions.size() > 0,
         "Could not find any partitions. Please check Kafka configuration and topic names");
 
-    int numSplits = Math.min(desiredNumSplits, partitions.size());
-    // XXX make all splits have the same # of partitions
-    while (partitions.size() % numSplits > 0) {
-      ++numSplits;
+    int numSplits;
+    if (offsetBasedDeduplicationSupported()) {
+      // Enforce 1:1 split to partition ratio for offset deduplication.
+      numSplits = partitions.size();
+      LOG.info(
+          "Offset-based deduplication is enabled for KafkaUnboundedSource. "
+              + "Forcing the number of splits to equal the total number of partitions: {}.",
+          numSplits);
+    } else {
+      numSplits = Math.min(desiredNumSplits, partitions.size());
+      // Make all splits have the same # of partitions.
+      while (partitions.size() % numSplits > 0) {
+        ++numSplits;
+      }
     }
     List<List<TopicPartition>> assignments = new ArrayList<>(numSplits);
 
@@ -177,6 +187,11 @@ public boolean requiresDeduping() {
     return false;
   }
 
+  @Override
+  public boolean offsetBasedDeduplicationSupported() {
+    return spec.isOffsetDeduplication();
+  }
+
   @Override
   public Coder<KafkaRecord<K, V>> getOutputCoder() {
     Coder<K> keyCoder = Preconditions.checkStateNotNull(spec.getKeyCoder());
diff --git a/sdks/java/io/kafka/src/test/java/org/apache/beam/sdk/io/kafka/KafkaIOExternalTest.java b/sdks/java/io/kafka/src/test/java/org/apache/beam/sdk/io/kafka/KafkaIOExternalTest.java
index f021789a912c..77158e818a07 100644
--- a/sdks/java/io/kafka/src/test/java/org/apache/beam/sdk/io/kafka/KafkaIOExternalTest.java
+++ b/sdks/java/io/kafka/src/test/java/org/apache/beam/sdk/io/kafka/KafkaIOExternalTest.java
@@ -111,7 +111,8 @@ public void testConstructKafkaRead() throws Exception {
                         Field.of("consumer_polling_timeout", FieldType.INT64),
                         Field.of("redistribute_num_keys", FieldType.INT32),
                         Field.of("redistribute", FieldType.BOOLEAN),
-                        Field.of("allow_duplicates", FieldType.BOOLEAN)))
+                        Field.of("allow_duplicates", FieldType.BOOLEAN),
+                        Field.of("offset_deduplication", FieldType.BOOLEAN)))
                 .withFieldValue("topics", topics)
                 .withFieldValue("consumer_config", consumerConfig)
                 .withFieldValue("key_deserializer", keyDeserializer)
@@ -123,6 +124,7 @@ public void testConstructKafkaRead() throws Exception {
                 .withFieldValue("redistribute_num_keys", 0)
                 .withFieldValue("redistribute", false)
                 .withFieldValue("allow_duplicates", false)
+                .withFieldValue("offset_deduplication", false)
                 .build());
 
     RunnerApi.Components defaultInstance = RunnerApi.Components.getDefaultInstance();
@@ -145,7 +147,7 @@ public void testConstructKafkaRead() throws Exception {
     expansionService.expand(request, observer);
     ExpansionApi.ExpansionResponse result = observer.result;
     RunnerApi.PTransform transform = result.getTransform();
-    System.out.println("xxx : " + result.toString());
+    System.out.println("Expansion result: " + result.toString());
     assertThat(
         transform.getSubtransformsList(),
         Matchers.hasItem(MatchesPattern.matchesPattern(".*KafkaIO-Read.*")));
@@ -247,7 +249,8 @@ public void testConstructKafkaReadWithoutMetadata() throws Exception {
                         Field.of("timestamp_policy", FieldType.STRING),
                         Field.of("redistribute_num_keys", FieldType.INT32),
                         Field.of("redistribute", FieldType.BOOLEAN),
-                        Field.of("allow_duplicates", FieldType.BOOLEAN)))
+                        Field.of("allow_duplicates", FieldType.BOOLEAN),
+                        Field.of("offset_deduplication", FieldType.BOOLEAN)))
                 .withFieldValue("topics", topics)
                 .withFieldValue("consumer_config", consumerConfig)
                 .withFieldValue("key_deserializer", keyDeserializer)
@@ -258,6 +261,7 @@ public void testConstructKafkaReadWithoutMetadata() throws Exception {
                 .withFieldValue("redistribute_num_keys", 0)
                 .withFieldValue("redistribute", false)
                 .withFieldValue("allow_duplicates", false)
+                .withFieldValue("offset_deduplication", false)
                 .build());
 
     RunnerApi.Components defaultInstance = RunnerApi.Components.getDefaultInstance();
@@ -281,6 +285,7 @@ public void testConstructKafkaReadWithoutMetadata() throws Exception {
     ExpansionApi.ExpansionResponse result = observer.result;
     RunnerApi.PTransform transform = result.getTransform();
 
+    System.out.println("Expansion result: " + result.toString());
     assertThat(
         transform.getSubtransformsList(),
         Matchers.hasItem(MatchesPattern.matchesPattern(".*KafkaIO-Read.*")));
diff --git a/sdks/java/io/kafka/src/test/java/org/apache/beam/sdk/io/kafka/KafkaIOReadImplementationCompatibilityTest.java b/sdks/java/io/kafka/src/test/java/org/apache/beam/sdk/io/kafka/KafkaIOReadImplementationCompatibilityTest.java
index 29c920bf9a6f..cef427a0c575 100644
--- a/sdks/java/io/kafka/src/test/java/org/apache/beam/sdk/io/kafka/KafkaIOReadImplementationCompatibilityTest.java
+++ b/sdks/java/io/kafka/src/test/java/org/apache/beam/sdk/io/kafka/KafkaIOReadImplementationCompatibilityTest.java
@@ -18,6 +18,7 @@
 package org.apache.beam.sdk.io.kafka;
 
 import static org.apache.beam.sdk.io.kafka.KafkaIOTest.mkKafkaReadTransform;
+import static org.apache.beam.sdk.io.kafka.KafkaIOTest.mkKafkaReadTransformWithOffsetDedup;
 import static org.hamcrest.MatcherAssert.assertThat;
 import static org.hamcrest.Matchers.containsInAnyOrder;
 import static org.hamcrest.Matchers.empty;
@@ -114,7 +115,8 @@ private PipelineResult testReadTransformCreationWithImplementationBoundPropertie
                 new ValueAsTimestampFn(),
                 false, /*redistribute*/
                 false, /*allowDuplicates*/
-                0)));
+                0, /*numKeys*/
+                false /*offsetDeduplication*/)));
     return p.run();
   }
 
@@ -139,6 +141,17 @@ public void testReadTransformCreationWithLegacyImplementationBoundProperty() {
     assertThat(Lineage.query(r.metrics(), Lineage.Type.SOURCE), containsInAnyOrder(expect));
   }
 
+  @Test
+  public void testReadTransformCreationWithOffsetDeduplication() {
+    p.apply(mkKafkaReadTransformWithOffsetDedup(1000, new ValueAsTimestampFn()));
+    PipelineResult r = p.run();
+    String[] expect =
+        KafkaIOTest.mkKafkaTopics.stream()
+            .map(topic -> String.format("kafka:`%s`.%s", KafkaIOTest.mkKafkaServers, topic))
+            .toArray(String[]::new);
+    assertThat(Lineage.query(r.metrics(), Lineage.Type.SOURCE), containsInAnyOrder(expect));
+  }
+
   @Test
   public void testReadTransformCreationWithSdfImplementationBoundProperty() {
     PipelineResult r =
diff --git a/sdks/java/io/kafka/src/test/java/org/apache/beam/sdk/io/kafka/KafkaIOTest.java b/sdks/java/io/kafka/src/test/java/org/apache/beam/sdk/io/kafka/KafkaIOTest.java
index e614320db150..6be5b7b2ecb9 100644
--- a/sdks/java/io/kafka/src/test/java/org/apache/beam/sdk/io/kafka/KafkaIOTest.java
+++ b/sdks/java/io/kafka/src/test/java/org/apache/beam/sdk/io/kafka/KafkaIOTest.java
@@ -391,7 +391,20 @@ static KafkaIO.Read<Integer, Long> mkKafkaReadTransform(
         timestampFn,
         false, /*redistribute*/
         false, /*allowDuplicates*/
-        0);
+        0, /*numKeys*/
+        false /*offsetDeduplication*/);
+  }
+
+  static KafkaIO.Read<Integer, Long> mkKafkaReadTransformWithOffsetDedup(
+      int numElements, @Nullable SerializableFunction<KV<Integer, Long>, Instant> timestampFn) {
+    return mkKafkaReadTransform(
+        numElements,
+        numElements,
+        timestampFn,
+        true, /*redistribute*/
+        false, /*allowDuplicates*/
+        100, /*numKeys*/
+        true /*offsetDeduplication*/);
   }
 
   /**
@@ -404,7 +417,8 @@ static KafkaIO.Read<Integer, Long> mkKafkaReadTransform(
       @Nullable SerializableFunction<KV<Integer, Long>, Instant> timestampFn,
       @Nullable Boolean redistribute,
       @Nullable Boolean withAllowDuplicates,
-      @Nullable Integer numKeys) {
+      @Nullable Integer numKeys,
+      @Nullable Boolean offsetDeduplication) {
 
     KafkaIO.Read<Integer, Long> reader =
         KafkaIO.<Integer, Long>read()
@@ -427,15 +441,15 @@ static KafkaIO.Read<Integer, Long> mkKafkaReadTransform(
       reader = reader.withTimestampFn(timestampFn);
     }
 
-    if (redistribute) {
+    if (redistribute != null && redistribute) {
+      reader = reader.withRedistribute();
+      reader = reader.withAllowDuplicates(withAllowDuplicates);
       if (numKeys != null) {
-        reader =
-            reader
-                .withRedistribute()
-                .withAllowDuplicates(withAllowDuplicates)
-                .withRedistributeNumKeys(numKeys);
+        reader = reader.withRedistributeNumKeys(numKeys);
       }
-      reader = reader.withRedistribute();
+    }
+    if (offsetDeduplication != null) {
+      reader = reader.withOffsetDeduplication(offsetDeduplication);
     }
     return reader;
   }
@@ -667,7 +681,8 @@ public void warningsWithAllowDuplicatesEnabledAndCommitOffsets() {
                         new ValueAsTimestampFn(),
                         true, /*redistribute*/
                         true, /*allowDuplicates*/
-                        0)
+                        0, /*numKeys*/
+                        false /*offsetDeduplication*/)
                     .commitOffsetsInFinalize()
                     .withConsumerConfigUpdates(
                         ImmutableMap.of(ConsumerConfig.GROUP_ID_CONFIG, "group_id"))
@@ -693,7 +708,8 @@ public void noWarningsWithNoAllowDuplicatesAndCommitOffsets() {
                         new ValueAsTimestampFn(),
                         true, /*redistribute*/
                         false, /*allowDuplicates*/
-                        0)
+                        0, /*numKeys*/
+                        false /*offsetDeduplication*/)
                     .commitOffsetsInFinalize()
                     .withConsumerConfigUpdates(
                         ImmutableMap.of(ConsumerConfig.GROUP_ID_CONFIG, "group_id"))
@@ -720,7 +736,8 @@ public void testNumKeysIgnoredWithRedistributeNotEnabled() {
                         new ValueAsTimestampFn(),
                         false, /*redistribute*/
                         false, /*allowDuplicates*/
-                        0)
+                        0, /*numKeys*/
+                        false /*offsetDeduplication*/)
                     .withRedistributeNumKeys(100)
                     .commitOffsetsInFinalize()
                     .withConsumerConfigUpdates(
@@ -2091,7 +2108,8 @@ public void testUnboundedSourceStartReadTime() {
                         new ValueAsTimestampFn(),
                         false, /*redistribute*/
                         false, /*allowDuplicates*/
-                        0)
+                        0, /*numKeys*/
+                        false /*offsetDeduplication*/)
                     .withStartReadTime(new Instant(startTime))
                     .withoutMetadata())
             .apply(Values.create());
@@ -2100,6 +2118,20 @@ public void testUnboundedSourceStartReadTime() {
     p.run();
   }
 
+  @Test
+  public void testOffsetDeduplication() {
+    int numElements = 1000;
+
+    PCollection<Long> input =
+        p.apply(
+                mkKafkaReadTransformWithOffsetDedup(numElements, new ValueAsTimestampFn())
+                    .withoutMetadata())
+            .apply(Values.create());
+
+    addCountingAsserts(input, numElements, numElements, 0, numElements - 1);
+    p.run();
+  }
+
   @Rule public ExpectedException noMessagesException = ExpectedException.none();
 
   @Test
@@ -2121,7 +2153,8 @@ public void testUnboundedSourceStartReadTimeException() {
                     new ValueAsTimestampFn(),
                     false, /*redistribute*/
                     false, /*allowDuplicates*/
-                    0)
+                    0, /*numKeys*/
+                    false /*offsetDeduplication*/)
                 .withStartReadTime(new Instant(startTime))
                 .withoutMetadata())
         .apply(Values.create());
diff --git a/sdks/java/io/kafka/upgrade/src/main/java/org/apache/beam/sdk/io/kafka/upgrade/KafkaIOTranslation.java b/sdks/java/io/kafka/upgrade/src/main/java/org/apache/beam/sdk/io/kafka/upgrade/KafkaIOTranslation.java
index 841236969d25..3d10b96c00b5 100644
--- a/sdks/java/io/kafka/upgrade/src/main/java/org/apache/beam/sdk/io/kafka/upgrade/KafkaIOTranslation.java
+++ b/sdks/java/io/kafka/upgrade/src/main/java/org/apache/beam/sdk/io/kafka/upgrade/KafkaIOTranslation.java
@@ -101,6 +101,7 @@ static class KafkaIOReadWithMetadataTranslator implements TransformPayloadTransl
             .addBooleanField("redistribute")
             .addBooleanField("allows_duplicates")
             .addNullableInt32Field("redistribute_num_keys")
+            .addBooleanField("offset_deduplication")
             .addNullableLogicalTypeField("watch_topic_partition_duration", new NanosDuration())
             .addByteArrayField("timestamp_policy_factory")
             .addNullableMapField("offset_consumer_config", FieldType.STRING, FieldType.BYTES)
@@ -221,6 +222,7 @@ public Row toConfigRow(Read<?, ?> transform) {
       fieldValues.put("redistribute", transform.isRedistributed());
       fieldValues.put("redistribute_num_keys", transform.getRedistributeNumKeys());
       fieldValues.put("allows_duplicates", transform.isAllowDuplicates());
+      fieldValues.put("offset_deduplication", transform.isOffsetDeduplication());
       return Row.withSchema(schema).withFieldValues(fieldValues).build();
     }
 
@@ -349,6 +351,10 @@ public Row toConfigRow(Read<?, ?> transform) {
             }
           }
         }
+        Boolean offsetDeduplication = configRow.getValue("offset_deduplication");
+        if (offsetDeduplication != null) {
+          transform = transform.withOffsetDeduplication(offsetDeduplication);
+        }
         Duration maxReadTime = configRow.getValue("max_read_time");
         if (maxReadTime != null) {
           transform =