Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Small convertTo fix #800

Merged
merged 7 commits into from
Aug 1, 2024
1 change: 1 addition & 0 deletions core/build.gradle.kts
Original file line number Diff line number Diff line change
Expand Up @@ -73,6 +73,7 @@ dependencies {

api(libs.kotlin.datetimeJvm)
implementation(libs.kotlinpoet)
implementation(libs.kotlinLogging)

testImplementation(libs.junit)
testImplementation(libs.kotestAssertions) {
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -121,7 +121,7 @@ internal fun <T> Iterable<T>.anyNull(): Boolean = any { it == null }
internal fun emptyPath(): ColumnPath = ColumnPath(emptyList())

@PublishedApi
internal fun <T : Number> KClass<T>.zero(): T =
internal fun <T : Number> KClass<T>.zeroOrNull(): T? =
when (this) {
Int::class -> 0 as T
Byte::class -> 0.toByte() as T
Expand All @@ -131,10 +131,14 @@ internal fun <T : Number> KClass<T>.zero(): T =
Float::class -> 0.toFloat() as T
BigDecimal::class -> BigDecimal.ZERO as T
BigInteger::class -> BigInteger.ZERO as T
Number::class -> 0 as T
else -> TODO()
Number::class -> 0 as? T
else -> null
}

@PublishedApi
internal fun <T : Number> KClass<T>.zero(): T =
zeroOrNull() ?: throw NotImplementedError("Zero value for $this is not supported")

internal fun <T> catchSilent(body: () -> T): T? =
try {
body()
Expand Down
Original file line number Diff line number Diff line change
@@ -1,5 +1,6 @@
package org.jetbrains.kotlinx.dataframe.impl.api

import io.github.oshai.kotlinlogging.KotlinLogging
import org.jetbrains.kotlinx.dataframe.AnyFrame
import org.jetbrains.kotlinx.dataframe.AnyRow
import org.jetbrains.kotlinx.dataframe.ColumnsSelector
Expand All @@ -11,13 +12,13 @@ import org.jetbrains.kotlinx.dataframe.api.ConvertSchemaDsl
import org.jetbrains.kotlinx.dataframe.api.ConverterScope
import org.jetbrains.kotlinx.dataframe.api.ExcessiveColumns
import org.jetbrains.kotlinx.dataframe.api.Infer
import org.jetbrains.kotlinx.dataframe.api.add
import org.jetbrains.kotlinx.dataframe.api.all
import org.jetbrains.kotlinx.dataframe.api.allNulls
import org.jetbrains.kotlinx.dataframe.api.asColumnGroup
import org.jetbrains.kotlinx.dataframe.api.concat
import org.jetbrains.kotlinx.dataframe.api.convertTo
import org.jetbrains.kotlinx.dataframe.api.emptyDataFrame
import org.jetbrains.kotlinx.dataframe.api.getColumnPaths
import org.jetbrains.kotlinx.dataframe.api.isEmpty
import org.jetbrains.kotlinx.dataframe.api.map
import org.jetbrains.kotlinx.dataframe.api.name
Expand All @@ -29,12 +30,14 @@ import org.jetbrains.kotlinx.dataframe.columns.ColumnGroup
import org.jetbrains.kotlinx.dataframe.columns.ColumnKind
import org.jetbrains.kotlinx.dataframe.columns.ColumnPath
import org.jetbrains.kotlinx.dataframe.columns.FrameColumn
import org.jetbrains.kotlinx.dataframe.columns.UnresolvedColumnsPolicy
import org.jetbrains.kotlinx.dataframe.columns.toColumnSet
import org.jetbrains.kotlinx.dataframe.exceptions.ExcessiveColumnsException
import org.jetbrains.kotlinx.dataframe.exceptions.TypeConversionException
import org.jetbrains.kotlinx.dataframe.impl.emptyPath
import org.jetbrains.kotlinx.dataframe.impl.schema.createEmptyColumn
import org.jetbrains.kotlinx.dataframe.impl.getColumnPaths
import org.jetbrains.kotlinx.dataframe.impl.schema.createEmptyDataFrame
import org.jetbrains.kotlinx.dataframe.impl.schema.createNullFilledColumn
import org.jetbrains.kotlinx.dataframe.impl.schema.extractSchema
import org.jetbrains.kotlinx.dataframe.impl.schema.render
import org.jetbrains.kotlinx.dataframe.kind
Expand All @@ -45,6 +48,8 @@ import kotlin.reflect.KType
import kotlin.reflect.full.withNullability
import kotlin.reflect.jvm.jvmErasure

private val logger = KotlinLogging.logger {}

private open class Converter(val transform: ConverterScope.(Any?) -> Any?, val skipNulls: Boolean)

private class Filler(val columns: ColumnsSelector<*, *>, val expr: RowExpression<*, *>)
Expand Down Expand Up @@ -252,22 +257,16 @@ internal fun AnyFrame.convertToImpl(
}
}.toMutableList()

// when the target is nullable but the source does not contain a column, fill it in with nulls / empty dataframes
// when the target is nullable but the source does not contain a column,
// fill it in with nulls / empty dataframes
val size = this.size.nrow
schema.columns.forEach { (name, targetColumn) ->
val isNullable =
// like value column of type Int?
targetColumn.nullable ||
// like value column of type Int? (backup check)
targetColumn.type.isMarkedNullable ||
// like DataRow<Something?> for a group column (all columns in the group will be nullable)
targetColumn.contentType?.isMarkedNullable == true ||
// frame column can be filled with empty dataframes
targetColumn.kind == ColumnKind.Frame

if (name !in visited) {
newColumns += targetColumn.createEmptyColumn(name, size)
if (!isNullable) {
try {
newColumns += targetColumn.createNullFilledColumn(name, size)
} catch (e: IllegalStateException) {
logger.debug(e) { "" }
// if this could not be done automatically, they need to be filled manually
missingPaths.add(path + name)
}
}
Expand All @@ -279,14 +278,39 @@ internal fun AnyFrame.convertToImpl(
val marker = MarkersExtractor.get(clazz)
var result = convertToSchema(marker.schema, emptyPath())

/*
* Here we handle all registered fillers of the user.
* Fillers are registered in the DSL like:
* ```kt
* df.convertTo<Target> {
* fill { col1 and col2 }.with { something }
* fill { col3 }.with { somethingElse }
* }
* ```
* Users can use this to fill up any column that was missing during the conversion.
* They can also fill up and thus overwrite any existing column here.
*/
dsl.fillers.forEach { filler ->
val paths = result.getColumnPaths(filler.columns)
missingPaths.removeAll(paths.toSet())
result = result.update { paths.toColumnSet() }.with {
filler.expr(this, this)
// get all paths from the `fill { col1 and col2 }` part
val paths = result.getColumnPaths(UnresolvedColumnsPolicy.Create, filler.columns).toSet()

// split the paths into those that are already in the df and those that are missing
val (newPaths, existingPaths) = paths.partition { it in missingPaths }

// first fill cols that are already in the df using the `with {}` part of the dsl
result = result.update { existingPaths.toColumnSet() }.with { filler.expr(this, this) }

// then create any missing ones by filling using the `with {}` part of the dsl
result = newPaths.fold(result) { df, newPath ->
df.add(newPath, Infer.Type) { filler.expr(this, this) }
}

// remove the paths that are now filled
missingPaths -= paths
}

// Inform the user which target columns could not be created in the conversion
// The user will need to supply extra information for these, like `fill {}` them.
if (missingPaths.isNotEmpty()) {
throw IllegalArgumentException(
"The following columns were not found in DataFrame: ${
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -102,6 +102,9 @@ internal fun AnyCol.extractSchema(): ColumnSchema =
@PublishedApi
internal fun getSchema(kClass: KClass<*>): DataFrameSchema = MarkersExtractor.get(kClass).schema

/**
* Create "empty" column based on the toplevel of [this] [ColumnSchema].
*/
internal fun ColumnSchema.createEmptyColumn(name: String): AnyCol =
when (this) {
is ColumnSchema.Value -> DataColumn.createValueColumn<Any?>(name, emptyList(), type)
Expand All @@ -110,14 +113,22 @@ internal fun ColumnSchema.createEmptyColumn(name: String): AnyCol =
else -> error("Unexpected ColumnSchema: $this")
}

/** Create "empty" column, filled with either null or empty dataframes. */
internal fun ColumnSchema.createEmptyColumn(name: String, numberOfRows: Int): AnyCol =
/**
* Creates a column based on [this] [ColumnSchema] filled with `null` or empty dataframes.
* @throws IllegalStateException if the column is not nullable and [numberOfRows]` > 0`.
*/
internal fun ColumnSchema.createNullFilledColumn(name: String, numberOfRows: Int): AnyCol =
when (this) {
is ColumnSchema.Value -> DataColumn.createValueColumn(
name = name,
values = List(numberOfRows) { null },
type = type,
)
is ColumnSchema.Value -> {
if (!type.isMarkedNullable && numberOfRows > 0) {
error("Cannot create a null-filled value column of type $type as it's not nullable.")
}
DataColumn.createValueColumn(
name = name,
values = List(numberOfRows) { null },
type = type,
)
}

is ColumnSchema.Group -> DataColumn.createColumnGroup(
name = name,
Expand All @@ -130,7 +141,7 @@ internal fun ColumnSchema.createEmptyColumn(name: String, numberOfRows: Int): An
schema = lazyOf(schema),
)

else -> error("Unexpected ColumnSchema: $this")
else -> error("Cannot create null-filled column of unexpected ColumnSchema: $this")
}

internal fun DataFrameSchema.createEmptyDataFrame(): AnyFrame =
Expand All @@ -143,7 +154,7 @@ internal fun DataFrameSchema.createEmptyDataFrame(numberOfRows: Int): AnyFrame =
DataFrame.empty(numberOfRows)
} else {
columns.map { (name, schema) ->
schema.createEmptyColumn(name, numberOfRows)
schema.createNullFilledColumn(name, numberOfRows)
}.toDataFrame()
}

Expand Down
Loading