Skip to content

Commit

Permalink
replace all java collectors.toSet with kotlin construct
Browse files Browse the repository at this point in the history
  • Loading branch information
stephane-airbyte committed May 23, 2024
1 parent b0ab148 commit 6a0211a
Show file tree
Hide file tree
Showing 13 changed files with 40 additions and 86 deletions.
Original file line number Diff line number Diff line change
Expand Up @@ -16,7 +16,6 @@ import io.airbyte.commons.util.MoreIterators
import java.util.*
import java.util.Spliterators.AbstractSpliterator
import java.util.function.Consumer
import java.util.stream.Collectors
import java.util.stream.Stream
import java.util.stream.StreamSupport
import org.bson.BsonDocument
Expand Down Expand Up @@ -57,9 +56,8 @@ class MongoDatabase(connectionString: String, databaseName: String) :
get() {
val collectionNames = database.listCollectionNames() ?: return Collections.emptySet()
return MoreIterators.toSet(collectionNames.iterator())
.stream()
.filter { c: String -> !c.startsWith(MONGO_RESERVED_COLLECTION_PREFIX) }
.collect(Collectors.toSet())
.toSet()
}

fun getCollection(collectionName: String): MongoCollection<Document> {
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -423,10 +423,7 @@ abstract class AbstractJdbcSource<Datatype>(
): Map<String, MutableList<String>> {
LOGGER.info(
"Discover primary keys for tables: " +
tableInfos
.stream()
.map { obj: TableInfo<CommonField<Datatype>> -> obj.name }
.collect(Collectors.toSet())
tableInfos.map { obj: TableInfo<CommonField<Datatype>> -> obj.name }.toSet()
)
try {
// Get all primary keys without specifying a table name
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -143,9 +143,8 @@ object DbSourceDiscoverUtil {
)
.withSourceDefinedPrimaryKey(primaryKeys)
}
// This is ugly. Some of our tests change the streams on the AirbyteCatalog
// object...
.toMutableList()
.toMutableList() // This is ugly, but we modify this list in
// JdbcSourceAcceptanceTest.testDiscoverWithMultipleSchemas
return AirbyteCatalog().withStreams(streams)
}

Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -9,7 +9,6 @@ import io.airbyte.protocol.models.v0.AirbyteStreamNameNamespacePair
import io.airbyte.protocol.models.v0.ConfiguredAirbyteCatalog
import io.airbyte.protocol.models.v0.ConfiguredAirbyteStream
import io.airbyte.protocol.models.v0.SyncMode
import java.util.stream.Collectors

object RelationalDbReadUtil {
fun identifyStreamsToSnapshot(
Expand Down Expand Up @@ -38,11 +37,10 @@ object RelationalDbReadUtil {
): List<ConfiguredAirbyteStream> {
val initialLoadStreamsNamespacePairs =
streamsForInitialLoad
.stream()
.map { stream: ConfiguredAirbyteStream ->
AirbyteStreamNameNamespacePair.fromAirbyteStream(stream.stream)
}
.collect(Collectors.toSet())
.toSet()
return catalog.streams
.stream()
.filter { c: ConfiguredAirbyteStream -> c.syncMode == SyncMode.INCREMENTAL }
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -92,7 +92,6 @@ class CursorManager<S : Any>(
): Map<AirbyteStreamNameNamespacePair, CursorInfo> {
val allStreamNames =
catalog.streams
.stream()
.filter { c: ConfiguredAirbyteStream ->
if (onlyIncludeIncrementalStreams) {
return@filter c.syncMode == SyncMode.INCREMENTAL
Expand All @@ -103,7 +102,7 @@ class CursorManager<S : Any>(
.map { stream: AirbyteStream ->
AirbyteStreamNameNamespacePair.fromAirbyteStream(stream)
}
.collect(Collectors.toSet())
.toMutableSet()
allStreamNames.addAll(
streamSupplier
.get()
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -105,15 +105,14 @@ class GlobalStateManager(
): Set<AirbyteStreamNameNamespacePair> {
if (airbyteStateMessage!!.type == AirbyteStateMessage.AirbyteStateType.GLOBAL) {
return airbyteStateMessage.global.streamStates
.stream()
.map { streamState: AirbyteStreamState ->
val cloned = Jsons.clone(streamState)
AirbyteStreamNameNamespacePair(
cloned.streamDescriptor.name,
cloned.streamDescriptor.namespace
)
}
.collect(Collectors.toSet())
.toSet()
} else {
val legacyState: DbState? =
Jsons.`object`(airbyteStateMessage.data, DbState::class.java)
Expand All @@ -127,12 +126,11 @@ class GlobalStateManager(
streams: List<DbStreamState>
): Set<AirbyteStreamNameNamespacePair> {
return streams
.stream()
.map { stream: DbStreamState ->
val cloned = Jsons.clone(stream)
AirbyteStreamNameNamespacePair(cloned.streamName, cloned.streamNamespace)
}
.collect(Collectors.toSet())
.toSet()
}

companion object {
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -16,8 +16,6 @@ import io.airbyte.protocol.models.JsonSchemaType
import io.airbyte.protocol.models.v0.*
import java.util.*
import java.util.function.Consumer
import java.util.stream.Collectors
import java.util.stream.Stream
import org.junit.jupiter.api.AfterEach
import org.junit.jupiter.api.Assertions
import org.junit.jupiter.api.BeforeEach
Expand Down Expand Up @@ -123,9 +121,8 @@ abstract class CdcSourceTest<S : Source, T : TestDatabase<*, T, *>> {
private fun assertStateDoNotHaveDuplicateStreams(stateMessage: AirbyteStateMessage) {
val dedupedStreamStates =
stateMessage.global.streamStates
.stream()
.map { streamState: AirbyteStreamState -> streamState.streamDescriptor }
.collect(Collectors.toSet())
.toSet()
Assertions.assertEquals(dedupedStreamStates.size, stateMessage.global.streamStates.size)
}

Expand Down Expand Up @@ -290,7 +287,7 @@ abstract class CdcSourceTest<S : Source, T : TestDatabase<*, T, *>> {
if (message.type == AirbyteMessage.Type.RECORD) {
val recordMessage = message.record
recordsPerStream
.computeIfAbsent(recordMessage.stream) { c: String -> ArrayList() }
.computeIfAbsent(recordMessage.stream) { _: String -> ArrayList() }
.add(recordMessage)
}
}
Expand Down Expand Up @@ -326,10 +323,7 @@ abstract class CdcSourceTest<S : Source, T : TestDatabase<*, T, *>> {
assertExpectedRecords(
expectedRecords,
actualRecords,
actualRecords
.stream()
.map { obj: AirbyteRecordMessage -> obj.stream }
.collect(Collectors.toSet()),
actualRecords.map { obj: AirbyteRecordMessage -> obj.stream }.toSet(),
)
}

Expand All @@ -356,8 +350,7 @@ abstract class CdcSourceTest<S : Source, T : TestDatabase<*, T, *>> {
) {
val actualData =
actualRecords
.stream()
.map<JsonNode> { recordMessage: AirbyteRecordMessage ->
.map { recordMessage: AirbyteRecordMessage ->
Assertions.assertTrue(streamNames.contains(recordMessage.stream))
Assertions.assertNotNull(recordMessage.emittedAt)

Expand All @@ -374,7 +367,7 @@ abstract class CdcSourceTest<S : Source, T : TestDatabase<*, T, *>> {
removeCDCColumns(data as ObjectNode)
data
}
.collect(Collectors.toSet())
.toSet()

Assertions.assertEquals(expectedRecords, actualData)
}
Expand Down Expand Up @@ -639,8 +632,7 @@ abstract class CdcSourceTest<S : Source, T : TestDatabase<*, T, *>> {
// Non resumeable full refresh does not get any state messages.
assertExpectedStateMessageCountMatches(stateMessages1, MODEL_RECORDS.size.toLong())
assertExpectedRecords(
Streams.concat(MODEL_RECORDS_2.stream(), MODEL_RECORDS.stream())
.collect(Collectors.toSet()),
(MODEL_RECORDS_2 + MODEL_RECORDS).toSet(),
recordMessages1,
setOf(MODELS_STREAM_NAME),
names,
Expand All @@ -657,17 +649,15 @@ abstract class CdcSourceTest<S : Source, T : TestDatabase<*, T, *>> {
assertExpectedStateMessagesFromIncrementalSync(stateMessages2)
assertExpectedStateMessageCountMatches(stateMessages2, 1)
assertExpectedRecords(
Streams.concat(MODEL_RECORDS_2.stream(), Stream.of(puntoRecord))
.collect(Collectors.toSet()),
(MODEL_RECORDS_2 + puntoRecord).toSet(),
recordMessages2,
setOf(MODELS_STREAM_NAME),
names,
modelsSchema(),
)
} else {
assertExpectedRecords(
Streams.concat(MODEL_RECORDS_2.stream(), MODEL_RECORDS.stream())
.collect(Collectors.toSet()),
(MODEL_RECORDS_2 + MODEL_RECORDS).toSet(),
recordMessages1,
setOf(MODELS_STREAM_NAME),
names,
Expand All @@ -687,8 +677,7 @@ abstract class CdcSourceTest<S : Source, T : TestDatabase<*, T, *>> {
val stateMessages2 = extractStateMessages(actualRecords2)

assertExpectedRecords(
Streams.concat(MODEL_RECORDS_2.stream(), Stream.of(puntoRecord))
.collect(Collectors.toSet()),
(MODEL_RECORDS_2 + puntoRecord).toSet(),
recordMessages2,
setOf(MODELS_STREAM_NAME),
names,
Expand All @@ -706,7 +695,7 @@ abstract class CdcSourceTest<S : Source, T : TestDatabase<*, T, *>> {
val actualRecords3 = AutoCloseableIterators.toListAndClose(read3)
val recordMessages3 = extractRecordMessages(actualRecords3)
assertExpectedRecords(
Streams.concat(MODEL_RECORDS_2.stream()).collect(Collectors.toSet()),
MODEL_RECORDS_2.toSet(),
recordMessages3,
setOf(MODELS_STREAM_NAME),
names,
Expand Down Expand Up @@ -790,8 +779,7 @@ abstract class CdcSourceTest<S : Source, T : TestDatabase<*, T, *>> {
// Non resumeable full refresh does not get any state messages.
assertExpectedStateMessageCountMatches(stateMessages1, MODEL_RECORDS.size.toLong())
assertExpectedRecords(
Streams.concat(MODEL_RECORDS_2.stream(), MODEL_RECORDS.stream())
.collect(Collectors.toSet()),
(MODEL_RECORDS_2 + MODEL_RECORDS).toSet(),
recordMessages1,
setOf(MODELS_STREAM_NAME),
names,
Expand All @@ -808,8 +796,7 @@ abstract class CdcSourceTest<S : Source, T : TestDatabase<*, T, *>> {
assertExpectedStateMessagesFromIncrementalSync(stateMessages2)
assertExpectedStateMessageCountMatches(stateMessages2, 1)
assertExpectedRecords(
Streams.concat(MODEL_RECORDS_2.stream(), Stream.of(puntoRecord))
.collect(Collectors.toSet()),
(MODEL_RECORDS_2 + puntoRecord).toSet(),
recordMessages2,
setOf(MODELS_STREAM_NAME),
names,
Expand Down Expand Up @@ -917,9 +904,8 @@ abstract class CdcSourceTest<S : Source, T : TestDatabase<*, T, *>> {
Assertions.assertNotNull(stateMessageEmittedAfterFirstSyncCompletion.global.sharedState)
val streamsInStateAfterFirstSyncCompletion =
stateMessageEmittedAfterFirstSyncCompletion.global.streamStates
.stream()
.map { obj: AirbyteStreamState -> obj.streamDescriptor }
.collect(Collectors.toSet())
.toSet()
Assertions.assertEquals(1, streamsInStateAfterFirstSyncCompletion.size)
Assertions.assertTrue(
streamsInStateAfterFirstSyncCompletion.contains(
Expand Down Expand Up @@ -1009,9 +995,8 @@ abstract class CdcSourceTest<S : Source, T : TestDatabase<*, T, *>> {
HashSet(MODEL_RECORDS_RANDOM),
recordsForModelsRandomStreamFromSecondBatch,
recordsForModelsRandomStreamFromSecondBatch
.stream()
.map { obj: AirbyteRecordMessage -> obj.stream }
.collect(Collectors.toSet()),
.toSet(),
Sets.newHashSet(RANDOM_TABLE_NAME),
randomSchema(),
)
Expand Down Expand Up @@ -1078,9 +1063,8 @@ abstract class CdcSourceTest<S : Source, T : TestDatabase<*, T, *>> {
)
val streamsInSyncCompletionStateAfterThirdSync =
stateMessageEmittedAfterThirdSyncCompletion.global.streamStates
.stream()
.map { obj: AirbyteStreamState -> obj.streamDescriptor }
.collect(Collectors.toSet())
.toSet()
Assertions.assertTrue(
streamsInSyncCompletionStateAfterThirdSync.contains(
StreamDescriptor().withName(RANDOM_TABLE_NAME).withNamespace(randomSchema()),
Expand Down Expand Up @@ -1109,9 +1093,8 @@ abstract class CdcSourceTest<S : Source, T : TestDatabase<*, T, *>> {
recordsWrittenInRandomTable,
recordsForModelsRandomStreamFromThirdBatch,
recordsForModelsRandomStreamFromThirdBatch
.stream()
.map { obj: AirbyteRecordMessage -> obj.stream }
.collect(Collectors.toSet()),
.toSet(),
Sets.newHashSet(RANDOM_TABLE_NAME),
randomSchema(),
)
Expand All @@ -1138,9 +1121,8 @@ abstract class CdcSourceTest<S : Source, T : TestDatabase<*, T, *>> {
Assertions.assertNotNull(stateMessageEmittedAfterFirstSyncCompletion.global.sharedState)
val streamsInStateAfterFirstSyncCompletion =
stateMessageEmittedAfterFirstSyncCompletion.global.streamStates
.stream()
.map { obj: AirbyteStreamState -> obj.streamDescriptor }
.collect(Collectors.toSet())
.toSet()
Assertions.assertEquals(1, streamsInStateAfterFirstSyncCompletion.size)
Assertions.assertTrue(
streamsInStateAfterFirstSyncCompletion.contains(
Expand Down Expand Up @@ -1256,9 +1238,8 @@ abstract class CdcSourceTest<S : Source, T : TestDatabase<*, T, *>> {
)
val streamsInSnapshotState =
stateMessageEmittedAfterSnapshotCompletionInSecondSync.global.streamStates
.stream()
.map { obj: AirbyteStreamState -> obj.streamDescriptor }
.collect(Collectors.toSet())
.toSet()
Assertions.assertEquals(2, streamsInSnapshotState.size)
Assertions.assertTrue(
streamsInSnapshotState.contains(
Expand All @@ -1283,9 +1264,8 @@ abstract class CdcSourceTest<S : Source, T : TestDatabase<*, T, *>> {
)
val streamsInSyncCompletionState =
stateMessageEmittedAfterSecondSyncCompletion.global.streamStates
.stream()
.map { obj: AirbyteStreamState -> obj.streamDescriptor }
.collect(Collectors.toSet())
.toSet()
Assertions.assertEquals(2, streamsInSnapshotState.size)
Assertions.assertTrue(
streamsInSyncCompletionState.contains(
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -20,7 +20,6 @@ import io.airbyte.protocol.models.Field
import io.airbyte.protocol.models.JsonSchemaType
import io.airbyte.protocol.models.v0.*
import java.math.BigDecimal
import java.nio.ByteBuffer
import java.sql.Connection
import java.util.*
import org.junit.jupiter.api.Assertions
Expand Down Expand Up @@ -172,7 +171,6 @@ abstract class JdbcStressTest {
}
.peek { m: AirbyteMessage -> assertExpectedMessage(m) }
.count()
var a: ByteBuffer
val expectedRoundedRecordsCount = TOTAL_RECORDS - TOTAL_RECORDS % 1000
LOGGER.info("expected records count: " + TOTAL_RECORDS)
LOGGER.info("actual records count: $actualCount")
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -7,11 +7,9 @@ package io.airbyte.commons.enums
import com.google.common.base.Preconditions
import com.google.common.collect.Maps
import com.google.common.collect.Sets
import java.util.Arrays
import java.util.Locale
import java.util.Optional
import java.util.concurrent.ConcurrentMap
import java.util.stream.Collectors

class Enums {
companion object {
Expand Down Expand Up @@ -54,12 +52,8 @@ class Enums {
Preconditions.checkArgument(c2.isEnum)
return (c1.enumConstants.size == c2.enumConstants.size &&
Sets.difference(
Arrays.stream(c1.enumConstants)
.map { obj: T1 -> obj!!.name }
.collect(Collectors.toSet()),
Arrays.stream(c2.enumConstants)
.map { obj: T2 -> obj!!.name }
.collect(Collectors.toSet()),
c1.enumConstants.map { obj: T1 -> obj!!.name }.toSet(),
c2.enumConstants.map { obj: T2 -> obj!!.name }.toSet(),
)
.isEmpty())
}
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -11,7 +11,6 @@ import java.io.File
import java.io.IOException
import java.net.URI
import java.net.URISyntaxException
import java.util.stream.Collectors
import me.andrz.jackson.JsonContext
import me.andrz.jackson.JsonReferenceException
import me.andrz.jackson.JsonReferenceProcessor
Expand Down Expand Up @@ -89,9 +88,8 @@ class JsonSchemaValidator @VisibleForTesting constructor(private val baseUri: UR

fun validate(schemaJson: JsonNode, objectJson: JsonNode): Set<String> {
return validateInternal(schemaJson, objectJson)
.stream()
.map { obj: ValidationMessage -> obj.message }
.collect(Collectors.toSet())
.toSet()
}

fun getValidationMessageArgs(schemaJson: JsonNode, objectJson: JsonNode): List<Array<String>> {
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -10,7 +10,6 @@ import io.airbyte.commons.text.Names
import io.airbyte.protocol.models.SyncMode
import io.airbyte.validation.json.JsonValidationException
import java.util.*
import java.util.stream.Collectors

/**
* Utilities to convert Catalog protocol to Catalog API client. This class was similar to existing
Expand Down Expand Up @@ -76,9 +75,8 @@ object CatalogClientConverters {
// field path.
val selectedFieldNames =
config.selectedFields!!
.stream()
.map { field: SelectedFieldInfo -> field.fieldPath!![0] }
.collect(Collectors.toSet())
.toSet()
// TODO(mfsiega-airbyte): we only check the top level of the cursor/primary key fields
// because we
// don't support filtering nested fields yet.
Expand Down
Loading

0 comments on commit 6a0211a

Please sign in to comment.