fix UT tests and old client compatibility
zhongqiangczq committed Aug 27, 2023
1 parent 3efaf65 commit 7740c3b
Showing 11 changed files with 151 additions and 82 deletions.
@@ -30,7 +30,10 @@
import org.apache.celeborn.common.network.client.TransportClient;
import org.apache.celeborn.common.network.protocol.*;
import org.apache.celeborn.common.network.util.NettyUtils;
import org.apache.celeborn.common.protocol.MessageType;
import org.apache.celeborn.common.protocol.PartitionLocation;
import org.apache.celeborn.common.protocol.PbOpenStream;
import org.apache.celeborn.common.protocol.PbStreamHandler;
import org.apache.celeborn.plugin.flink.network.FlinkTransportClientFactory;

public class CelebornBufferStream {
@@ -131,13 +134,13 @@ private void cleanStream(long streamId) {
mapShuffleClient.getReadClientHandler().removeHandler(streamId);
clientFactory.unregisterSupplier(streamId);
closeStream(streamId);
isOpenSuccess = false;
}

public void close() {
synchronized (lock) {
if (isOpenSuccess) {
cleanStream(streamId);
isOpenSuccess = false;
}
isClosed = true;
}
@@ -173,36 +176,56 @@ private void openStreamInternal() throws IOException, InterruptedException {
locations[currentLocationIndex].getHost(),
locations[currentLocationIndex].getFetchPort());
String fileName = locations[currentLocationIndex].getFileName();
OpenStreamWithCredit openBufferStream =
new OpenStreamWithCredit(shuffleKey, fileName, subIndexStart, subIndexEnd, initialCredit);
TransportMessage openStream =
new TransportMessage(
MessageType.OPEN_STREAM,
PbOpenStream.newBuilder()
.setShuffleKey(shuffleKey)
.setFileName(fileName)
.setStartIndex(subIndexStart)
.setEndIndex(subIndexEnd)
.setInitialCredit(initialCredit)
.build()
.toByteArray());
client.sendRpc(
openBufferStream.toByteBuffer(),
openStream.toByteBuffer(),
new RpcResponseCallback() {

@Override
public void onSuccess(ByteBuffer response) {
StreamHandle streamHandle = (StreamHandle) Message.decode(response);
CelebornBufferStream.this.streamId = streamHandle.streamId;
synchronized (lock) {
if (!isClosed) {
clientFactory.registerSupplier(CelebornBufferStream.this.streamId, bufferSupplier);
mapShuffleClient
.getReadClientHandler()
.registerHandler(streamId, messageConsumer, client);
isOpenSuccess = true;
logger.debug(
"open stream success from remote:{}, stream id:{}, fileName: {}",
client.getSocketAddress(),
streamId,
fileName);
} else {
logger.debug(
"open stream success from remote:{}, but stream reader is already closed, stream id:{}, fileName: {}",
client.getSocketAddress(),
streamId,
fileName);
closeStream(streamId);
try {
PbStreamHandler pbStreamHandler =
TransportMessage.fromByteBuffer(response).getParsedPayload();
CelebornBufferStream.this.streamId = pbStreamHandler.getStreamId();
synchronized (lock) {
if (!isClosed) {
clientFactory.registerSupplier(
CelebornBufferStream.this.streamId, bufferSupplier);
mapShuffleClient
.getReadClientHandler()
.registerHandler(streamId, messageConsumer, client);
isOpenSuccess = true;
logger.debug(
"open stream success from remote:{}, stream id:{}, fileName: {}",
client.getSocketAddress(),
streamId,
fileName);
} else {
logger.debug(
"open stream success from remote:{}, but stream reader is already closed, stream id:{}, fileName: {}",
client.getSocketAddress(),
streamId,
fileName);
closeStream(streamId);
}
}
} catch (Exception e) {
logger.error(
"Failed to open stream for file {} of shuffle {} from {}",
fileName,
shuffleKey,
NettyUtils.getRemoteAddress(client.getChannel()),
e);
messageConsumer.accept(new TransportableError(streamId, e));
}
}
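
Taken together, the change above swaps the Java-serialized OpenStreamWithCredit/StreamHandle pair for protobuf payloads carried in a TransportMessage envelope. A minimal sketch of the new round trip, assembled from the identifiers in this hunk (error handling and the surrounding callback plumbing elided):

TransportMessage openStream =
    new TransportMessage(
        MessageType.OPEN_STREAM,
        PbOpenStream.newBuilder()
            .setShuffleKey(shuffleKey)
            .setFileName(fileName)
            .setStartIndex(subIndexStart)
            .setEndIndex(subIndexEnd)
            .setInitialCredit(initialCredit)
            .build()
            .toByteArray());
client.sendRpc(openStream.toByteBuffer(), callback);

// On success, the response parses back into a typed protobuf message:
PbStreamHandler pbStreamHandler =
    TransportMessage.fromByteBuffer(response).getParsedPayload();
long streamId = pbStreamHandler.getStreamId();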

4 changes: 2 additions & 2 deletions common/src/main/proto/TransportMessages.proto
@@ -339,8 +339,8 @@ message PbReserveSlots {
bool rangeReadFilter = 8;
PbUserIdentifier userIdentifier = 9;
int64 pushDataTimeout = 10;
//now just used for flink client
bool splitEnabled=11;
// now just used for flink client
bool splitEnabled = 11;
}
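
For reference, a Flink client would opt into splitting through the generated builder; setSplitEnabled is the setter protobuf generates for field 11 above, and the other fields are elided here:

PbReserveSlots reserveSlots =
    PbReserveSlots.newBuilder()
        // ... shuffleKey, slots, userIdentifier, etc. ...
        .setSplitEnabled(true)
        .build();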

message PbReserveSlotsResponse {
@@ -24,7 +24,7 @@ import java.util.function.BiFunction
import scala.collection.JavaConverters._

import org.apache.celeborn.common.internal.Logging
import org.apache.celeborn.common.protocol.PartitionLocation
import org.apache.celeborn.common.protocol.{PartitionLocation, PartitionType}
import org.apache.celeborn.common.util.JavaUtils

class WorkerPartitionLocationInfo extends Logging {
@@ -183,8 +183,18 @@
|""".stripMargin
}

def getPrimaryPartitionLocations(): PartitionInfo = primaryPartitionLocations

def getReplicaPartitionLocations(): PartitionInfo = replicaPartitionLocations
def getPrimaryPartitionLocationsByFiler(f: String => Boolean)
: Array[collection.Map[String, ConcurrentHashMap[String, PartitionLocation]]] = {
val primary = primaryPartitionLocations.asScala.filterKeys(f)
val replica = replicaPartitionLocations.asScala.filterKeys(f)
if (primary.size > 0) {
if (replica.size > 0) {
Array(primary, replica)
} else {
Array(primary)
}
} else {
Array()
}
}

}
@@ -64,7 +64,5 @@ public void invoke(Long value, Context context) throws Exception {
// Thread.sleep(30 * 1000);
}
});

env.execute("Shuffle Task");
}
}
@@ -17,13 +17,12 @@

package org.apache.celeborn.tests.flink

import java.io.File

import scala.collection.JavaConverters._

import org.apache.flink.api.common.RuntimeExecutionMode
import org.apache.flink.api.common.{ExecutionMode, RuntimeExecutionMode}
import org.apache.flink.configuration.{Configuration, ExecutionOptions, RestOptions}
import org.apache.flink.runtime.jobgraph.{JobGraph, JobType}
import org.apache.flink.runtime.minicluster.{MiniCluster, MiniClusterConfiguration}
import org.apache.flink.streaming.api.environment.StreamExecutionEnvironment
import org.apache.flink.streaming.api.graph.{GlobalStreamExchangeMode, StreamingJobGraphGenerator}
import org.scalatest.BeforeAndAfterAll
import org.scalatest.funsuite.AnyFunSuite

@@ -35,6 +34,7 @@ import org.apache.celeborn.service.deploy.worker.Worker
class SplitTest extends AnyFunSuite with Logging with MiniClusterFeature
with BeforeAndAfterAll {
var workers: collection.Set[Worker] = null
var flinkCluster: MiniCluster = null
override def beforeAll(): Unit = {
logInfo("test initialized , setup rss mini cluster")
val masterConf = Map(
@@ -49,10 +49,13 @@ class SplitTest extends AnyFunSuite with Logging with MiniClusterFeature

override def afterAll(): Unit = {
logInfo("all test complete , stop rss mini cluster")
if (flinkCluster != null) {
flinkCluster.close()
}
shutdownMiniCluster()
}

ignore("celeborn flink integration test - shuffle partition split test") {
test("celeborn flink integration test - shuffle partition split test") {
val configuration = new Configuration
val parallelism = 8
configuration.setString(
@@ -69,10 +72,23 @@ class SplitTest extends AnyFunSuite with Logging with MiniClusterFeature
configuration.setString(
"execution.batch.adaptive.auto-parallelism.max-parallelism",
"" + parallelism)
configuration.setString(CelebornConf.SHUFFLE_PARTITION_SPLIT_THRESHOLD.key, "100k")
configuration.setString(CelebornConf.SHUFFLE_PARTITION_SPLIT_THRESHOLD.key, "10k")
configuration.setString(CelebornConf.CLIENT_FLINK_SHUFFLE_PARTITION_SPLIT_ENABLED.key, "true");
val env = StreamExecutionEnvironment.createLocalEnvironmentWithWebUI(configuration)
env.setRuntimeMode(RuntimeExecutionMode.BATCH)
env.getConfig.setExecutionMode(ExecutionMode.BATCH)
env.getConfig.setParallelism(parallelism)
SplitHelper.runSplitRead(env)
val miniClusterConfiguration =
(new MiniClusterConfiguration.Builder).setConfiguration(configuration).build()
flinkCluster = new MiniCluster(miniClusterConfiguration)
flinkCluster.start()
val graph = env.getStreamGraph
graph.setGlobalStreamExchangeMode(GlobalStreamExchangeMode.ALL_EDGES_BLOCKING)
graph.setJobType(JobType.BATCH)
val jobGraph: JobGraph = StreamingJobGraphGenerator.createJobGraph(graph)
val jobID = flinkCluster.submitJob(jobGraph).get.getJobID
val jobResult = flinkCluster.requestJobResult(jobID).get
if (jobResult.getSerializedThrowable.isPresent)
throw new AssertionError(jobResult.getSerializedThrowable.get)
}
}
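
This test and WordCountTest below replace env.execute with an explicit Flink MiniCluster so that the JobResult can be inspected and a failed job surfaces its own exception. The pattern, condensed into a Java sketch (the Scala tests here do the same thing; JobID comes from org.apache.flink.api.common, JobResult from org.apache.flink.runtime.jobmaster, checked exceptions elided):

MiniClusterConfiguration miniClusterConfiguration =
    new MiniClusterConfiguration.Builder().setConfiguration(configuration).build();
MiniCluster flinkCluster = new MiniCluster(miniClusterConfiguration);
flinkCluster.start();

JobGraph jobGraph = StreamingJobGraphGenerator.createJobGraph(graph);
JobID jobID = flinkCluster.submitJob(jobGraph).get().getJobID();
JobResult jobResult = flinkCluster.requestJobResult(jobID).get();
if (jobResult.getSerializedThrowable().isPresent()) {
  throw new AssertionError(jobResult.getSerializedThrowable().get());
}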
@@ -23,9 +23,10 @@ import scala.collection.JavaConverters._

import org.apache.flink.api.common.{ExecutionMode, RuntimeExecutionMode}
import org.apache.flink.configuration.{Configuration, ExecutionOptions, RestOptions}
import org.apache.flink.runtime.jobgraph.JobType
import org.apache.flink.runtime.jobgraph.{JobGraph, JobType}
import org.apache.flink.runtime.minicluster.{MiniCluster, MiniClusterConfiguration}
import org.apache.flink.streaming.api.environment.StreamExecutionEnvironment
import org.apache.flink.streaming.api.graph.GlobalStreamExchangeMode
import org.apache.flink.streaming.api.graph.{GlobalStreamExchangeMode, StreamingJobGraphGenerator}
import org.scalatest.BeforeAndAfterAll
import org.scalatest.funsuite.AnyFunSuite

@@ -36,7 +37,7 @@ import org.apache.celeborn.service.deploy.worker.Worker
class WordCountTest extends AnyFunSuite with Logging with MiniClusterFeature
with BeforeAndAfterAll {
var workers: collection.Set[Worker] = null

var flinkCluster: MiniCluster = null
override def beforeAll(): Unit = {
logInfo("test initialized , setup celeborn mini cluster")
val masterConf = Map(
Expand All @@ -48,6 +49,9 @@ class WordCountTest extends AnyFunSuite with Logging with MiniClusterFeature

override def afterAll(): Unit = {
logInfo("all test complete , stop celeborn mini cluster")
if (flinkCluster != null) {
flinkCluster.close()
}
shutdownMiniCluster()
}

@@ -76,7 +80,16 @@ class WordCountTest extends AnyFunSuite with Logging with MiniClusterFeature
val graph = env.getStreamGraph
graph.setGlobalStreamExchangeMode(GlobalStreamExchangeMode.ALL_EDGES_BLOCKING)
graph.setJobType(JobType.BATCH)
env.execute(graph)
val miniClusterConfiguration =
(new MiniClusterConfiguration.Builder).setConfiguration(configuration).build()
flinkCluster = new MiniCluster(miniClusterConfiguration)
flinkCluster.start()

val jobGraph: JobGraph = StreamingJobGraphGenerator.createJobGraph(graph)
val jobID = flinkCluster.submitJob(jobGraph).get.getJobID
val jobResult = flinkCluster.requestJobResult(jobID).get
if (jobResult.getSerializedThrowable.isPresent)
throw new AssertionError(jobResult.getSerializedThrowable.get)
checkFlushingFileLength()
}

@@ -78,7 +78,19 @@ public long registerStream(
int initialCredit,
int startSubIndex,
int endSubIndex,
FileInfo fileInfo)
FileInfo fileInfo) throws IOException {
// Six-argument overload kept for old-client compatibility: delegate to the new
// overload and mark the stream legacy so the server skips BufferStreamEnd.
return registerStream(
notifyStreamHandlerCallback, channel, initialCredit, startSubIndex, endSubIndex, fileInfo, true);
}

public long registerStream(
Consumer<Long> notifyStreamHandlerCallback,
Channel channel,
int initialCredit,
int startSubIndex,
int endSubIndex,
FileInfo fileInfo,
boolean isLegacy)
throws IOException {
long streamId = nextStreamId.getAndIncrement();
logger.debug(
Expand All @@ -102,7 +114,8 @@ public long registerStream(
threadsPerMountPoint,
fileInfo,
id -> recycleStream(id),
minBuffersToTriggerRead);
minBuffersToTriggerRead,
isLegacy);
} catch (IOException e) {
exception.set(e);
return null;
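
The overload pair keeps the API source-compatible: callers that still use the six-argument signature are treated as legacy, while updated call sites state the flag explicitly. A sketch of the two call shapes (streamManager and the shortened argument names are placeholders, not from this diff):

// Old call sites compile unchanged and are implicitly legacy:
long legacyStreamId =
    streamManager.registerStream(callback, channel, credit, startSubIndex, endSubIndex, fileInfo);

// Updated call sites, such as FetchHandler below, pass the flag themselves:
long streamId =
    streamManager.registerStream(callback, channel, credit, startSubIndex, endSubIndex, fileInfo, false);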
@@ -61,6 +61,7 @@ class MapDataPartition implements MemoryManager.ReadBufferTargetChangeListener {
private int maxReadBuffers;
private int minBuffersToTriggerRead;
private AtomicBoolean hasReadingTask = new AtomicBoolean(false);
private boolean isLegacy;

public MapDataPartition(
int minReadBuffers,
@@ -69,7 +70,8 @@ public MapDataPartition(
int threadsPerMountPoint,
FileInfo fileInfo,
Consumer<Long> recycleStream,
int minBuffersToTriggerRead)
int minBuffersToTriggerRead,
boolean isLegacy)
throws IOException {
this.recycleStream = recycleStream;
this.fileInfo = fileInfo;
Expand All @@ -78,12 +80,13 @@ public MapDataPartition(
this.maxReadBuffers = maxReadBuffers;

updateBuffersTarget((this.minReadBuffers + this.maxReadBuffers) / 2 + 1);

this.isLegacy = isLegacy;
logger.debug(
"read map partition {} with {} {} {}",
"read map partition {} with {} {} {} {}",
fileInfo.getFilePath(),
bufferQueue.getLocalBuffersTarget(),
fileInfo.getBufferSize());
fileInfo.getBufferSize(),
isLegacy);

this.minBuffersToTriggerRead = minBuffersToTriggerRead;

@@ -127,7 +130,8 @@ public void setupDataPartitionReader(
fileInfo,
streamId,
channel,
() -> recycleStream.accept(streamId));
() -> recycleStream.accept(streamId),
isLegacy);
readers.put(streamId, mapDataPartitionReader);
}

@@ -94,14 +94,16 @@ public class MapDataPartitionReader implements Comparable<MapDataPartitionReader

private AtomicInteger numInUseBuffers = new AtomicInteger(0);
private boolean isOpen = false;
private boolean isLegacy;

public MapDataPartitionReader(
int startPartitionIndex,
int endPartitionIndex,
FileInfo fileInfo,
long streamId,
Channel associatedChannel,
Runnable recycleStream) {
Runnable recycleStream,
boolean isLegacy) {
this.startPartitionIndex = startPartitionIndex;
this.endPartitionIndex = endPartitionIndex;

Expand All @@ -115,6 +117,7 @@ public MapDataPartitionReader(

this.fileInfo = fileInfo;
this.readFinished = false;
this.isLegacy = isLegacy;
}

public void open(FileChannel dataFileChannel, FileChannel indexFileChannel, long indexSize)
Expand Down Expand Up @@ -408,8 +411,9 @@ public FileInfo getFileInfo() {
public void closeReader() {
synchronized (lock) {
readFinished = true;
// tell client that this stream is finished.
associatedChannel.writeAndFlush(new BufferStreamEnd(streamId));
// Old clients cannot handle BufferStreamEnd, so the end-of-stream notification
// is sent only to new (non-legacy) clients.
if (!isLegacy) {
associatedChannel.writeAndFlush(new BufferStreamEnd(streamId));
}
}
logger.debug("Closed read for stream {}", this.streamId);
}
@@ -231,7 +231,8 @@ class FetchHandler(val conf: CelebornConf, val transportConf: TransportConf)
initialCredit,
startIndex,
endIndex,
fileInfo)
fileInfo,
isLegacy)
case PartitionType.MAPGROUP =>
}
} catch {
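
How isLegacy is computed is outside this hunk. A plausible reading, stated here as an assumption rather than something the diff confirms, is that the handler marks a request legacy when it arrived as the old Java-serialized open-stream message instead of a protobuf PbOpenStream:

// Hypothetical sketch only; the actual dispatch logic is not shown in this diff.
boolean isLegacy = !(request instanceof TransportMessage);
// legacy OpenStreamWithCredit -> isLegacy = true -> no BufferStreamEnd on close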