Skip to content

Commit

Permalink
[FLINK-36856][runtime] CollectSinkOperatorFactory batch size and sock…
Browse files Browse the repository at this point in the history
…et timeout config fix
  • Loading branch information
gaborgsomogyi authored Dec 9, 2024
1 parent bbe1946 commit a1abb6a
Show file tree
Hide file tree
Showing 4 changed files with 48 additions and 1 deletion.
Original file line number Diff line number Diff line change
Expand Up @@ -51,6 +51,7 @@
import org.apache.flink.api.java.tuple.Tuple;
import org.apache.flink.api.java.typeutils.InputTypeConfigurable;
import org.apache.flink.api.java.typeutils.TypeExtractor;
import org.apache.flink.configuration.MemorySize;
import org.apache.flink.configuration.RpcOptions;
import org.apache.flink.core.execution.JobClient;
import org.apache.flink.core.fs.FileSystem.WriteMode;
Expand Down Expand Up @@ -109,6 +110,7 @@
import org.apache.flink.util.OutputTag;
import org.apache.flink.util.Preconditions;

import java.time.Duration;
import java.util.ArrayList;
import java.util.List;
import java.util.UUID;
Expand Down Expand Up @@ -1439,8 +1441,13 @@ public void collectAsync(Collector<T> collector) {
String accumulatorName = "dataStreamCollect_" + UUID.randomUUID().toString();

StreamExecutionEnvironment env = getExecutionEnvironment();
MemorySize maxBatchSize =
env.getConfiguration().get(CollectSinkOperatorFactory.MAX_BATCH_SIZE);
Duration socketTimeout =
env.getConfiguration().get(CollectSinkOperatorFactory.SOCKET_TIMEOUT);
CollectSinkOperatorFactory<T> factory =
new CollectSinkOperatorFactory<>(serializer, accumulatorName);
new CollectSinkOperatorFactory<>(
serializer, accumulatorName, maxBatchSize, socketTimeout);
CollectSinkOperator<T> operator = (CollectSinkOperator<T>) factory.getOperator();
long resultFetchTimeout =
env.getConfiguration().get(RpcOptions.ASK_TIMEOUT_DURATION).toMillis();
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -167,6 +167,10 @@ public CollectSinkFunction(
this.accumulatorName = accumulatorName;
}

public long getMaxBytesPerBatch() {
return maxBytesPerBatch;
}

private void initBuffer() {
if (buffer != null) {
return;
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -58,6 +58,10 @@ public CollectSinkOperatorFactory(
this.socketTimeoutMillis = (int) socketTimeout.toMillis();
}

public int getSocketTimeoutMillis() {
return socketTimeoutMillis;
}

@Override
@SuppressWarnings("unchecked")
public <T extends StreamOperator<Object>> T createStreamOperator(
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -18,10 +18,16 @@
package org.apache.flink.api.datastream;

import org.apache.flink.api.common.RuntimeExecutionMode;
import org.apache.flink.api.dag.Transformation;
import org.apache.flink.configuration.Configuration;
import org.apache.flink.configuration.ExecutionOptions;
import org.apache.flink.configuration.MemorySize;
import org.apache.flink.streaming.api.datastream.DataStream;
import org.apache.flink.streaming.api.environment.StreamExecutionEnvironment;
import org.apache.flink.streaming.api.operators.collect.CollectSinkFunction;
import org.apache.flink.streaming.api.operators.collect.CollectSinkOperator;
import org.apache.flink.streaming.api.operators.collect.CollectSinkOperatorFactory;
import org.apache.flink.streaming.api.transformations.LegacySinkTransformation;
import org.apache.flink.util.CloseableIterator;
import org.apache.flink.util.CollectionUtil;
import org.apache.flink.util.TestLogger;
Expand All @@ -30,6 +36,7 @@
import org.junit.Assert;
import org.junit.Test;

import java.time.Duration;
import java.util.List;
import java.util.function.Consumer;

Expand Down Expand Up @@ -111,6 +118,31 @@ public void testBoundedCollectAndLimit() throws Exception {
results.size());
}

@Test
public void testAsyncCollectWithSinkConfigs() {
Configuration configuration = new Configuration();
configuration.set(CollectSinkOperatorFactory.SOCKET_TIMEOUT, Duration.ofMillis(2));
configuration.set(CollectSinkOperatorFactory.MAX_BATCH_SIZE, new MemorySize(3));
final StreamExecutionEnvironment env =
StreamExecutionEnvironment.getExecutionEnvironment(configuration);

final DataStream<Integer> stream = env.fromData(1, 2, 3, 4, 5);
stream.collectAsync();

List<Transformation<?>> transformations = env.getTransformations();
Assert.assertEquals(1, transformations.size());
LegacySinkTransformation<?> transformation =
(LegacySinkTransformation<?>) transformations.get(transformations.size() - 1);
CollectSinkOperatorFactory<?> collectSinkOperatorFactory =
(CollectSinkOperatorFactory<?>) transformation.getOperatorFactory();
CollectSinkFunction<?> collectSinkFunction =
((CollectSinkFunction<?>)
((CollectSinkOperator<?>) collectSinkOperatorFactory.getOperator())
.getUserFunction());
Assert.assertEquals(2, collectSinkOperatorFactory.getSocketTimeoutMillis());
Assert.assertEquals(3, collectSinkFunction.getMaxBytesPerBatch());
}

@Test
public void testAsyncCollect() throws Exception {
final StreamExecutionEnvironment env = StreamExecutionEnvironment.getExecutionEnvironment();
Expand Down

0 comments on commit a1abb6a

Please sign in to comment.